//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
// CNTK.core.bs -- core BrainScript library including both general and CNTK-specific definitions
//

##############################################################################
# Layer constructors
#
# A layer constructor is a stateful function that creates and returns an instance
# of a 'learnable function'. A learnable function is a function object that has
# learnable parameters baked into it, which get trained by SGD.
# Calling a layer constructor twice creates two instances with independent parameters.
#
# Learnable function instances can be applied to data or composed directly into
# more complex models. For example:
#   // immediate usage:
#   z = LinearLayer{9000}(h)  # LinearLayer{9000} returns a new function object
#   // composing multiple layers into a model
#   model = Sequential ( DenseLayer{2048, activation=Sigmoid} : LinearLayer {9000} )
#   z = model (features)
#   // applying the same model to two inputs, with shared, jointly updated parameters
#   f = DenseLayer{2048, activation=ReLU}
#   z1 = f (feat1) ; z2 = f (feat2)
# The names are intentionally kept similar to other toolkits.
#
# Note that functions without parameters can be used as layers directly, e.g. Sigmoid.
##############################################################################

# LinearLayer -- constructor for a fully-connected linear projection, W * x (+ b)
# 'outDim' may be a scalar or describe a tensor.
# Rank options (at most one of 'inputRank' and 'mapRank' may be given):
#   inputRank given: number of inferred input axes to add to W
#   mapRank   given: expand W to leave exactly mapRank axes
#   none      given: expand W to all input axes (same as mapRank=0)
LinearLayer {outDim, bias = true, init='glorotUniform', initValueScale=1, inputRank=None, mapRank=None, initBias=0} =
{
    hasInputRank = !BS.Constants.IsNone (inputRank)
    hasMapRank   = !BS.Constants.IsNone (mapRank)

    # value handed to Times() as 'inferInputRankToMap'
    inferInputRankToMap =
        if      hasInputRank then -1        # means not specified
        else if hasMapRank   then mapRank
        else                      0         # default to 'use all input dims'

    # input-side part of the shape of W
    inputShape =
        if      !hasInputRank then Inferred  # not given: one Inferred, which will get expanded
        else if hasMapRank    then Fail ("'inputRank' and 'mapRank' cannot be specified at the same time.")
        else                       Repeat (inputRank, Inferred)

    outputRank = Length (_AsArray (outDim))  # number of output axes; supports outputs with tensor layouts

    W = ParameterTensor {_ConcatArrays (outDim, inputShape), init=init, initValueScale=initValueScale}
    b = ParameterTensor {outDim, initValue=initBias}

    apply (x) =
        if bias
        then Times (W, x, outputRank=outputRank, inferInputRankToMap=inferInputRankToMap) + b
        else Times (W, x, outputRank=outputRank, inferInputRankToMap=inferInputRankToMap)
}.apply

# DenseLayer -- constructor for a fully-connected layer: a LinearLayer followed by 'activation'
# (identity unless specified). All remaining parameters are forwarded to LinearLayer unchanged.
DenseLayer{outDim, bias = true, activation=(x=>x), init='glorotUniform', initValueScale=1, inputRank=None, mapRank=None, initBias=0} =
    Sequential (LinearLayer {outDim,
                             bias = bias,
                             init = init, initValueScale = initValueScale,
                             inputRank = inputRank, mapRank = mapRank,
                             initBias = initBias} : activation)

# EmbeddingLayer -- constructor for a linear embedding layer
# The embedding matrix E is either learned, or loaded from 'embeddingPath' and kept fixed.
# With transpose=true, E is stored with the embedding dimension last and applied via TransposeTimes.
EmbeddingLayer {outDim,                                   # dimension of embedding
                init='glorotUniform', initValueScale=1,
                embeddingPath = '', transpose = false} =  # load a fixed embedding from a path instead
{
    embeddingShape = if !transpose then (outDim : Inferred) else (Inferred : outDim)
    isLearned = embeddingPath == ''   # no path given: E is a regular learnable parameter
    E = if isLearned
        then ParameterTensor {embeddingShape, init=init, initValueScale=initValueScale}
        else ParameterTensor {embeddingShape, initFromFilePath = embeddingPath, learningRateMultiplier = 0}  # frozen
    TimesOp = if !transpose then Times else TransposeTimes
    apply (x) = TimesOp (E, x)    # x is expected to be sparse one-hot
}.apply

# ConvolutionalLayer -- constructor for a convolution layer with optional bias and non-linearity
#             [ (shifting dims)  |  (reduction dim)  |  (output dim)  |  (sample dims) ]
#    in     : [ (shifting dims)  |  (reduction dim)  |                |  (sample dims) ]
#    kernel : [ (filter dims)    |  (reduction dim)  |  (output dim)  |                ]
#    out    : [ (shifting dims)  |                   |  (output dim)  |  (sample dims) ]
# BUGBUG: filterShape should be first, so that numOutputChannels can default to 1 (to denote a normal filter), and remaining parameters consistent with Times()
ConvolutionalLayer {numOutputChannels,   # e.g. (1) or BS.Constants.None
                    filterShape,         # e.g. (3:3)
                    bias = true,
                    activation = (x=>x),
                    init = 'glorotUniform',
                    initValueScale = 1,          # TODO: rename to initScale
                    initBias = 0,
                    #reductionRank = 1,          # TODO: support this
                    stride = 1, pad = false,
                    lowerPad = 0, upperPad = 0,
                    dilation = 1,
                    maxTempMemSizeInSamples = 0} =
{
    reductionRank = 1 # TODO: shall become an optional parameter
    filterRank = Length (filterShape)
    outputChannelsShape = _AsArray (numOutputChannels)

    # kernel := filter dims followed by the (inferred) reduction dims
    kernelShape = _ConcatArrays (filterShape, Repeat (reductionRank, Inferred))
    weightShape = _ConcatArrays (kernelShape, outputChannelsShape)              # [ W x H x C x K ]
    biasShape   = _ConcatArrays (Repeat (filterRank, 1), outputChannelsShape)   # [ 1 x 1 x     K ]

    W = ParameterTensor {weightShape, init = init, initValueScale = initValueScale, initFilterRank = filterRank, initOutputRank = -1}
    b = ParameterTensor {biasShape, initValue = initBias}

    sharing = true    # TODO: support this
    apply (x) = {
        c = Convolution (W, x, filterShape,
                         mapDims = numOutputChannels,
                         stride = stride, sharing = sharing,
                         autoPadding = pad, lowerPad = lowerPad, upperPad = upperPad,
                         dilation = dilation,
                         maxTempMemSizeInSamples = maxTempMemSizeInSamples)
        res = activation (if bias then c + b else c)   # the bias is only added if requested
    }.res
}.apply

# ConvolutionTransposeLayer -- constructor for a transposed-convolution layer with optional bias and non-linearity
# The weight tensor has shape filterShape x numOutputChannels x numInputChannels.
ConvolutionTransposeLayer {numOutputChannels,
                           filterShape,         # e.g. (3:3)
                           numInputChannels,
                           bias = true,
                           activation = (x=>x),
                           init = 'glorotUniform',
                           initValueScale = 0.001,
                           initBias = 0,
                           stride = 1, pad = false,
                           lowerPad = 0, upperPad = 0,
                           outputShape = None,
                           dilation = 1,
                           maxTempMemSizeInSamples = 0} =
{
    filterRank = Length (filterShape)
    outputChannelsShape = _AsArray (numOutputChannels)

    kernelShape = _ConcatArrays (filterShape, outputChannelsShape)            # filter dims, then output channels
    paramShape  = _ConcatArrays (kernelShape, _AsArray (numInputChannels))    # ... then input channels
    biasShape   = _ConcatArrays (Repeat (filterRank, 1), outputChannelsShape) # 1 along every filter dim

    W = ParameterTensor {paramShape, init = init, initValueScale = initValueScale, initOnCPUOnly = true}
    b = ParameterTensor {biasShape, initValue = initBias}

    sharing = true    # TODO: support this
    apply (x) = {
        c = ConvolutionTranspose (W, x, kernelShape,
                                  mapDims = numInputChannels,
                                  stride = stride, sharing = sharing,
                                  autoPadding = pad, lowerPad = lowerPad, upperPad = upperPad,
                                  outputShape = outputShape,
                                  dilation = dilation,
                                  maxTempMemSizeInSamples = maxTempMemSizeInSamples)
        res = activation (if bias then c + b else c)   # the bias is only added if requested
    }.res
}.apply

# BilinearUpsamplingLayer -- constructor for an upsampling layer, implemented as a transposed
# convolution whose weights are initialized with init='bilinear' and excluded from learning
# (learningRateMultiplier = 0). No padding is applied and the output shape is left unspecified.
BilinearUpsamplingLayer {numChannels, kernelSize, stride} =
{
    kernelShape = (kernelSize : kernelSize : numChannels)
    paramShape  = _ConcatArrays (kernelShape, _AsArray (numChannels))
    W = ParameterTensor {paramShape,
                         init = 'bilinear', initValueScale = 1,
                         initOnCPUOnly = true,
                         learningRateMultiplier = 0}
    sharing = true    # TODO: support this
    apply (x) = {
        c = ConvolutionTranspose (W, x, kernelShape,
                                  mapDims = numChannels,
                                  stride = stride, sharing = sharing,
                                  autoPadding = false, lowerPad = 0, upperPad = 0,
                                  outputShape = None,
                                  maxTempMemSizeInSamples = 0)
    }.c
}.apply

# _PoolingLayer -- shared implementation behind MaxPoolingLayer and AveragePoolingLayer
# Returns a function that applies Pooling() of the given kind with the given geometry.
_PoolingLayer {poolKind,            # "max" or "average"
               filterShape,         # e.g. (3:3)
               stride = 1, pad = false,
               lowerPad = 0, upperPad = 0,
               ceilOutDim = false, includePad = false} = # TODO: support this
{
    apply (x) = Pooling (x, poolKind, filterShape,
                         stride = stride,
                         autoPadding = pad, lowerPad = lowerPad, upperPad = upperPad,
                         ceilOutDim = ceilOutDim, includePad = includePad)
}.apply
# MaxPoolingLayer -- create a max-pooling layer (includePad is always passed as false)
MaxPoolingLayer {filterShape, stride = 1, pad = false, lowerPad = 0, upperPad = 0, ceilOutDim = false} =
    _PoolingLayer {"max", filterShape,
                   stride = stride,
                   pad = pad, lowerPad = lowerPad, upperPad = upperPad,
                   ceilOutDim = ceilOutDim,
                   includePad = false}
# AveragePoolingLayer -- create an average-pooling layer; 'includePad' is forwarded to the pooling operation
AveragePoolingLayer {filterShape, stride = 1, pad = false, lowerPad = 0, upperPad = 0, ceilOutDim = false, includePad = false} =
    _PoolingLayer {"average", filterShape,
                   stride = stride,
                   pad = pad, lowerPad = lowerPad, upperPad = upperPad,
                   ceilOutDim = ceilOutDim,
                   includePad = includePad}

# MaxUnpoolingLayer -- create a max-unpooling layer
# The returned function takes two arguments, which are passed to MaxUnpooling() in this order:
# the value to unpool, and 'poolInput'.
MaxUnpoolingLayer {filterShape,     # e.g. (3:3)
                   stride = 1,
                   pad = false,
                   lowerPad = 0,
                   upperPad = 0} =
{
    apply (unpoolInput, poolInput) =
        MaxUnpooling (unpoolInput, poolInput, filterShape,
                      stride = stride,
                      autoPadding = pad, lowerPad = lowerPad, upperPad = upperPad)
}.apply

# RecurrentLSTMLayer -- create an LSTM layer
# Two implementations are selected between at construction time:
#  - OptimizedRNNStack with a single layer, used only if 'allowOptimizedEngine' is set and none of
#    cellShape, goBackwards, usePeepholes, trainInitialState, enableSelfStabilization is in use
#  - otherwise an explicit recurrence around BS.RNNs.LSTMBlock
# The returned function maps an input sequence x to the LSTM's output ('h' of the LSTM state).
RecurrentLSTMLayer {outputDim,
                    cellShape = None, # if set then use a projection
                    goBackwards = false,
                    usePeepholes = false,
                    trainInitialState = false,
                    init = 'glorotUniform', initValueScale = 1,
                    enableSelfStabilization = false,
                    allowOptimizedEngine = false} =
if allowOptimizedEngine && BS.Constants.IsNone (cellShape) && !goBackwards && !usePeepholes && !trainInitialState && !enableSelfStabilization then
{
    # use cudnn instead
    # all weights live in a single parameter tensor W, declared with shape 0:0
    W = ParameterTensor {0:0, initFilterRank=0, initOutputRank=-1, init=init, initValueScale=initValueScale}
    apply (x) = OptimizedRNNStack (W, x, outputDim, numLayers=1)
}.apply
else
{
    # select the function that provides the recurrent state to the LSTM:
    # 'Next...' variants when going backwards, 'Previous...' variants otherwise;
    # the '...WithTrainedInitialState' variants are constructed ({}) when trainInitialState is set
    previousHook =
        if trainInitialState
        then if goBackwards then BS.RNNs.NextHCWithTrainedInitialState{} else BS.RNNs.PreviousHCWithTrainedInitialState{}
        else if goBackwards then BS.RNNs.NextHC else BS.RNNs.PreviousHC
    lstm = BS.RNNs.LSTMBlock {outputDim, cellShape=cellShape, usePeepholes=usePeepholes, enableSelfStabilization=enableSelfStabilization, init=init, initValueScale=initValueScale}
    apply (x) = {
        # note: prevState and lstmState are defined in terms of each other; this forms the recurrent loop
        prevState = previousHook (lstmState) # recurrent memory. E.g. Previous or Next, with or without initial state, beam reordering etc.

        #auxInput = augmentInputHook(x, prevState)   # optionally augment input. Constants.None if none.

        lstmState = lstm (x, prevState)
    }.lstmState.h // the 'h' member of the LSTM state is the value we return
}.apply

# _AreElementsTheSame -- helper that tests whether the first L elements of 'arr' are all equal
# Trivially true for L < 2; otherwise compares the last two of these elements and recurses on L-1.
_AreElementsTheSame (arr, L) =
    if L >= 2
    then arr[L-2] == arr[L-1] && _AreElementsTheSame (arr, L-1)
    else true

# RecurrentLSTMLayerStack -- create a whole stack of LSTM layers
# If 'allowOptimizedEngine' is set, no projection/peepholes/self-stabilization are requested, and all
# entries of 'layerDims' are equal, the whole stack is a single OptimizedRNNStack node.
# Otherwise, one RecurrentLSTMLayer is created per entry of 'layerDims' and chained via LayerStack.
RecurrentLSTMLayerStack {layerDims,  # an array of dimensions for each layer (last one is output dimension)
                         cellShapes = None, # if set then use a projection
                         #bidirectional = false,    # TODO: add this
                         usePeepholes = false,
                         init = 'glorotUniform', initValueScale = 1,
                         enableSelfStabilization = false,
                         allowOptimizedEngine = false} =
if allowOptimizedEngine && BS.Constants.IsNone (cellShapes) && !usePeepholes && !enableSelfStabilization && _AreElementsTheSame (layerDims, Length(layerDims)) then
{
    # use cudnn instead
    W = ParameterTensor {0:0, initFilterRank=0, initOutputRank=-1, init=init, initValueScale=initValueScale}
    # all layers have the same dimension here, so layerDims[0] describes every layer
    apply (x) = OptimizedRNNStack (W, x, layerDims[0], numLayers=Length(layerDims))
}.apply
else
    # layer i gets dimension layerDims[i] and, if given, cell shape cellShapes[i]
    LayerStack {Length(layerDims), i => RecurrentLSTMLayer {
                                           layerDims[i],
                                           cellShape = if BS.Constants.IsNone (cellShapes) then None else cellShapes[i], # if set then use a projection
                                           goBackwards = false,
                                           usePeepholes = usePeepholes,
                                           init = init, initValueScale = initValueScale,
                                           enableSelfStabilization = enableSelfStabilization,
                                           allowOptimizedEngine = allowOptimizedEngine}}

# DelayLayer -- create a layer that delays its input by T steps
#   T > 0: PastValue with timeStep=T
#   T < 0: FutureValue with timeStep=-T
#   T = 0: the input is passed through unchanged
# 'defaultHiddenActivation' is forwarded to PastValue/FutureValue.
DelayLayer {T=1, defaultHiddenActivation=0} =
{
    apply (x) =
        if      T == 0 then x
        else if T > 0  then PastValue   (0, x, timeStep=T,  defaultHiddenActivation=defaultHiddenActivation)
        else                FutureValue (0, x, timeStep=-T, defaultHiddenActivation=defaultHiddenActivation)
}.apply

# DropoutLayer -- create a drop-out layer
# Not yet supported with this interface; just use Dropout directly.
#DropoutLayer {prob = BS.Constants.None} = if !BS.Constants.IsNone (prob) then Fail ("DropoutLayer: Dropout probability can currently not be specified per-layer.") else
#{
#    apply (x) = Dropout (x)
#}.apply

# BatchNormalizationLayer -- create a batch-normalization layer
# Owns the learnable scale and bias as well as the running statistics (mean, variance, count),
# which are excluded from gradient-based learning.
BatchNormalizationLayer {spatialRank = 0,  # reduce over these dims. E.g. 2 to reduce over (w,h) in a [W x H x C]-shaped input
                         initialScale = 1,
                         normalizationTimeConstant = 5000, blendTimeConstant = 0,
                         epsilon = 0.00001, useCntkEngine = false, disableRegularization = false} =
{
    #normShape   = _ConcatArrays (Repeat (spatialRank, 1), 0) # spatial dims get a dimension of 1 (broadcasting, while all others are inferred from input)
    normShape   = (0:1)  # TODO: Update this once we support broadcasting-style parameters.
    isSpatial   = spatialRank > 0   # any non-zero spatial rank selects spatial mode

    # learnable parameters
    scale       = ParameterTensor {normShape, initValue = initialScale}
    bias        = ParameterTensor {normShape, initValue = 0}

    # running statistics; note: learning is disabled since these are updated differently
    runMean     = ParameterTensor {normShape, initValue = 0, learningRateMultiplier = 0}
    runVariance = ParameterTensor {normShape, initValue = 0, learningRateMultiplier = 0}
    runCount    = ParameterTensor {(1),       initValue = 0, learningRateMultiplier = 0}

    apply (x)   = BatchNormalization (x, scale, bias, runMean, runVariance,
                                      runCount = runCount,
                                      spatial = isSpatial,
                                      normalizationTimeConstant = normalizationTimeConstant,
                                      blendTimeConstant = blendTimeConstant,
                                      epsilon = epsilon,
                                      useCntkEngine = useCntkEngine,
                                      disableRegularization = disableRegularization)
}.apply

# LayerNormalizationLayer -- create a layer-normalization layer
# Normalizes each sample to zero mean and unit standard deviation, computed over the sample itself,
# then applies a learned scalar gain and bias.
#   initScale: initial value of the gain
#   initBias:  initial value of the bias
#   epsilon:   optional constant added to the variance before taking the square root.
#              Without it, an input whose elements are all equal has a standard deviation of 0,
#              and the normalization divides by zero. The default of 0 keeps the previous
#              behavior (and the exact same graph); pass a small positive value, e.g. 0.00001,
#              to protect against this.
LayerNormalizationLayer {initScale = 1, initBias = 0, epsilon = 0} =
{
    gain = ParameterTensor{(1), initValue = initScale}  # TODO: offer Softplus version for protection, as for Stabilizer
    #        f = ConstantTensor (4, (1))
    #        fInv = Reciprocal (f)
    #        gain = fInv .* Log (BS.Constants.One + Exp (f .* ParameterTensor ((1), initValue=0.99537863/* 1/f*ln (e^f-1) */))) # init value is 1
    bias = ParameterTensor{(1), initValue = initBias}

    apply (x) = {
        # normalize w.r.t. actual sample statistics
        mean = ReduceMean (x)
        x0 = x - mean
        # the epsilon == 0 branch is the original expression, so default behavior is unchanged
        std = Sqrt (if epsilon == 0 then ReduceMean (x0 .* x0) else ReduceMean (x0 .* x0) + epsilon)
        xHat = ElementDivide (x0, std)

        # denormalize with learned parameters
        val = xHat .* gain + bias
    }.val
}.apply

# StabilizerLayer -- create a scalar stabilizer ["Self-stabilized deep neural network," P. Ghahremani and J. Droppo, ICASSP 2016]
# The returned function multiplies its input by a learned positive scalar 'beta'.
# Original author's note: the choice of 'steepness' seems to matter.
# NOTE: the initial value of 'param' below is hard-coded for steepness == 4, for which beta starts
# out as 1. For any other 'steepness', the same initial value is used, so beta does not start at 1.
StabilizerLayer{steepness=4} =
{
    # sharpened Softplus: 1/steepness ln(1+e^{steepness*beta})
    # this behaves linear for weights around 1, yet guarantees positiveness

    param = ParameterTensor ((1), initValue=0.99537863/* 1/steepness*ln (e^steepness-1) for steepness==4 */)

    apply (x) = {
        # beta = softplus of the learned parameter, sharpened by 'steepness'
        beta = Log (1 + Exp (steepness .* param)) / steepness
        res  = beta .* x
    }.res
}.apply

# FeatureMVNLayer -- create a corpus-level feature-normalization layer
# This can only be applied to features. Statistics are not shared across invocations,
# which is semantically OK because the values are the same. However, it is not efficient.
# Takes no parameters; the constructor simply returns the MeanVarNorm function.
FeatureMVNLayer{} = MeanVarNorm

# LogPriorLayer -- create a corpus-level label-prior layer
# This can only be applied to labels. Statistics are not shared across invocations,
# which is semantically OK because the values are the same. However, it is not efficient.
# Takes no parameters; the constructor simply returns the LogPrior function.
# TODO: document on Wiki
LogPriorLayer{} = LogPrior

# Layers that exist in other tools that we will not have:
# FlattenLayer{}: Not needed since DenseLayer() can handle tensors just fine.
# Activation{}: Not needed since functions can be used directly.

# Identity -- function that returns its argument unchanged
Identity(x) = x # sometimes helpful

# None -- short-hand for BS.Constants.None, the 'not specified' marker for optional parameters
None = BS.Constants.None   # for use with some optional parameters; test with IsNone()

# Inferred -- marker value (0) for a tensor dimension whose actual value is to be inferred
Inferred = 0  # denotes a dimension that is to be inferred

##############################################################################
# Composing layers or models into more complex models
##############################################################################

# Sequential -- composite that chains functions: the output of each one is the input to the next
# Sequential (F:G:H) === F >> G >> H, i.e. Sequential (F:G:H) (x) = H (G (F (x)))
# A single function that is not an array is accepted as well.
Sequential (arrayOfFunctions) =
{
    fs = _AsArray (arrayOfFunctions)  # also handles a single function that is not an array
    # Apply (x, N) applies the first N functions to x, in order; implemented recursively
    Apply (x, N) = if N > 0 then fs[N-1](Apply (x, N-1)) else x
    apply (x) = Apply (x, Length (fs))
    # TODO: change to this once '>>' has been changed to evaluate in the same order (right first)
    #apply = if Length(fs) == 0 then Identity else fs[0] >> Sequential(fs << 1)
}.apply
# Parallel -- composite that applies every function to the same input, then passes the
# array of all results to 'combineFunction'
# TODO: remove combineFunction; instead create an array of values
Parallel (arrayOfFunctions, combineFunction) =
{
    fs = _AsArray (arrayOfFunctions)
    lastIndex = Length (fs) - 1
    apply (x) = combineFunction (array[0..lastIndex] (i => fs[i](x)))
}.apply
# MergeBinary -- apply two functions and combine them with a binary function, e.g. Plus
# 'arrayOfFunctions' must contain exactly two functions F:G. The returned function takes two
# inputs (x, y) and computes combineFunction (F (x), G (y)).
# Fails at construction time if the number of functions is not 2.
MergeBinary (arrayOfFunctions, combineFunction) =
    if Length (arrayOfFunctions) != 2 then Fail ("MergeBinary() is currently limited to binary functions.") else
    {
        apply (x, y) = combineFunction (arrayOfFunctions[0](x), arrayOfFunctions[1](y))
    }.apply
# LayerStack -- generate a stack of models from a lambda of the form (i => some expression of i)
# The lambda 'c' is invoked for i = 0..n-1, and the resulting functions are chained with Sequential().
# e.g. h3 = LayerStack {3, i => MyConvLayer {(32:32:64)[i], (0.0043:1.414:1.414)[i]} } (featNorm)
LayerStack {n, c} = Sequential (array[0..n-1] (c))

##############################################################################
# aliases
##############################################################################

# Each alias below makes a member of the CNTK2 record available under a plain name at this scope.
Less                    = CNTK2.Less
Equal                   = CNTK2.Equal
Greater                 = CNTK2.Greater
GreaterEqual            = CNTK2.GreaterEqual
NotEqual                = CNTK2.NotEqual
LessEqual               = CNTK2.LessEqual
Splice                  = CNTK2.Splice
TransposeDimensions     = CNTK2.TransposeDimensions
Times                   = CNTK2.Times
Abs                     = CNTK2.Abs
Ceil                    = CNTK2.Ceil
CrossEntropyWithSoftmax = CNTK2.CrossEntropyWithSoftmax
Dropout                 = CNTK2.Dropout
ElementTimes            = CNTK2.ElementTimes
ElementDivide           = CNTK2.ElementDivide
ClassificationError     = CNTK2.ClassificationError
Exp                     = CNTK2.Exp
Floor                   = CNTK2.Floor
Log                     = CNTK2.Log
Minus                   = CNTK2.Minus
Pass                    = CNTK2.Pass
Plus                    = CNTK2.Plus
RectifiedLinear         = CNTK2.ReLU # deprecated name; refers to the same function as ReLU
ReLU                    = CNTK2.ReLU
ReduceSum               = CNTK2.ReduceSum
ReduceLogSum            = CNTK2.ReduceLogSum
ReduceMean              = CNTK2.ReduceMean
ReduceMin               = CNTK2.ReduceMin
ReduceMax               = CNTK2.ReduceMax

Round                   = CNTK2.Round
Sigmoid                 = CNTK2.Sigmoid
StraightThrough         = CNTK2.StraightThrough

##############################################################################
# ComputationNodes
##############################################################################

# _AsNodes -- helper that turns 'inputs' into an array in which every scalar has become a Constant node
#   Double values become a Constant of that value
#   Bool values   become a Constant of 1 (true) or 0 (false)
#   anything else is passed through unchanged
# A single value that is not an array is accepted as well.
_AsNodes (inputs, precision=precision) = {
    items = _AsArray(inputs)
    arr[i:0..Length(items)-1] =
    {
        item = items[i]
        res = if      IsDouble(item) then Constant {item, precision=precision}
              else if IsBool(item)   then Constant {if item then 1 else 0, precision=precision}
              else item
    }.res
}.arr

##############################################################################
# "Stable API" with the purpose of staying compatible towards 2.0.
# - Use only tensors as concept. Move away from matrices.
# - Main input goes first
# - Main input is called "_" (becomes either self, or is moved to the end,
#   depending on language binding)
# - tensor shape is called "shape"
# - output shape before input shape
# Operator list is sorted alphabetically within the category.
##############################################################################
CNTK2 = [
    # Currently restricted to operators introduced with Python API in CNTK 1.4.

    // 1. Inputs
    // Changes: dims -> shape
    // DynamicAxis: creates a node with operation 'DynamicAxis'; it has no inputs.
    DynamicAxis(tag='', precision=precision) = new ComputationNode [ operation = 'DynamicAxis' ; /*plus the function args*/  ]
    // Input: creates a node with operation 'InputValue' of the given shape; isImage is always false.
    // 'tag' defaults to 'feature'.
    # TODO: Is it a good idea to default to "feature"?
    Input(shape, dynamicAxis='', tag='feature', precision=precision) = new ComputationNode [ operation = 'InputValue' ; shape = new TensorShape [ /*shape*/ ] ; isImage = false /*plus the function args*/ ]

    // 2. Variables and constants
    // Changes: ParameterTensor -> _Parameter; "dims" -> "shape"
    // Python API:
    // - constant(value, name=None) - value: the tensor constant passed as numpy array. Forwards to parameter() with learningRateMultiplier=0 and initFromLiteral
    // - parameter: Like below, but can take an NDArray on the "init_from_literal" parameter, in which case it is serialized and turned into "initFromLiteral".
    //              (TODO: should be the value parameter instead)
    // TODO: The API for Parameter is different in current 2.0 design, getting a constant as input for the initial values.
    // This needs to be fixed to follow the way the Constant() is exposed in Python
    // Making this an internal node with "_" until we agree on the final interface:
    // _Parameter: creates a node with operation 'LearnableParameter' of the given shape.
    // The initialization options are passed on to the node through the function arguments.
    _Parameter(shape, value = 0, initValue = '', learningRateMultiplier = 1.0, init = ''/*|uniform|fixedValue|gaussian|fromFile|fromLiteral*/, initValueScale = 1, initFilterRank = 0, initOutputRank = 1, initFromFilePath = '', initFromLiteral = '', initOnCPUOnly=true, randomSeed=-1, tag='', precision=precision) = new ComputationNode [ operation = 'LearnableParameter' ; shape = new TensorShape [ /*shape */ ] /*plus the function args*/ ]

    // 3. Shape operations
    // Changes: NewReshape -> Reshape, input -> _, dims -> shape
    // Reshape: creates a node with operation 'Reshape' on input '_' with the given target shape;
    // beginAxis and endAxis both default to 0.
    Reshape(_, shape, beginAxis=0, endAxis=0, tag='', precision=precision) = new ComputationNode [ operation = 'Reshape' ; inputs = _AsNodes (_, precision=precision) ; shape = new TensorShape [ /*shape*/ ] /*plus the function args*/ ]
    // Slice: extracts the range [beginIndex, endIndex) along 'axis'.
    //  - axis >= 0: creates a node with operation 'Slice'
    //  - axis <  0: slices along the time axis, implemented by computing per-step flags and
    //    gathering the flagged steps. A positive index counts from the start of the sequence,
    //    a non-positive one from its end; an index of 0 places no restriction on that side.
    // NOTE(review): 'tag' and 'precision' are not referenced in the time-axis branch -- confirm this is intended.
    Slice(_, beginIndex, endIndex, axis=1, tag='', precision=precision) =
        if axis < 0 then [ # time axis: specify -1
            # beginFlags: steps that are not among the first 'beginIndex', or that are among the last '-beginIndex'
            beginFlags = if beginIndex > 0 then BS.Boolean.Not (BS.Loop.IsFirstN (beginIndex, _)) else                 BS.Loop.IsLastN  (-beginIndex, _)
            # endFlags: steps that are among the first 'endIndex', or that are not among the last '-endIndex'
            endFlags   = if endIndex   > 0 then                 BS.Loop.IsFirstN (endIndex,   _)  else BS.Boolean.Not (BS.Loop.IsLastN  (-endIndex,   _))
            # an index of 0 means that side is unrestricted, so only the other side's flags apply
            flags = if      beginIndex == 0 then endFlags
                    else if endIndex   == 0 then beginFlags
                    else                         BS.Boolean.And (beginFlags, endFlags)
            # with both indices 0, the input is returned as is, without a Gather
            out = if beginIndex == 0 && endIndex == 0
                  then _
                  else BS.Sequences.Gather (flags, _)
        ].out
        else new ComputationNode [ operation = 'Slice' ; inputs = _AsNodes (_, precision=precision) /*plus the function args*/ ] # non-time axis

    // Splice: concatenates the inputs '_' along 'axis' by forwarding to RowStack.
    // Fails for axis < 1, since splicing along the time axis is not implemented.
    Splice (_, axis=1, tag='', precision=precision) =
        if axis < 1 then Fail('Splice does not yet implement splicing the time axis.')
        else [tag1=tag; out = RowStack (_, axis=axis, tag=tag1, precision=precision)].out

    // Swap two axes of a tensor
    // Creates a node with operation 'TransposeDimensions'; axis1 and axis2 reach the node as function arguments.
    TransposeDimensions(_, axis1, axis2, tag='', precision=precision) = new ComputationNode [ operation = 'TransposeDimensions' ; inputs = _AsNodes (_, precision=precision) /*plus the function args*/ ]

    // 4. Tensor operations
    // Changes: Matrix -> Tensor. A -> x, B -> y. Data must come on y ("default parameter") hence not using _
    // Times: creates a node with operation 'Times' with inputs (x, y), in this order.
    // Defaults: outputRank=1, inferInputRankToMap=-1.
    Times(x, y, outputRank=1, inferInputRankToMap=-1, tag='', precision=precision) = new ComputationNode [ operation = 'Times' ; inputs = _AsNodes (x : y, precision=precision) /*plus the function args*/ ]

    // 5. Elementwise operations.
    // Changes: "Matrix" -> "Tensor"; left input -> _; Clip: move input to front. ElementDivide/Times: anotherTensor -> y
    // Unless noted otherwise, each function creates a single node of the operation of the same name.
    // Scalar arguments are converted to Constant nodes by _AsNodes.
    Abs(_, tag='', precision=precision) = new ComputationNode [ operation = 'Abs' ; inputs = _AsNodes (_, precision=precision) /*plus the function args*/ ]
    // Ceil is composed from existing operations: Ceil (x) = -Floor (-x)
    Ceil(_, tag='', precision=precision) = Negate(Floor(Negate(_, precision=precision), precision=precision), tag=tag, precision=precision)
    // Clip: note that the node's inputs are ordered (minValue, maxValue, _), unlike the function's arguments
    Clip(_, minValue, maxValue, tag='', precision=precision) = new ComputationNode [ operation = 'Clip' ; inputs = _AsNodes (minValue : maxValue : _, precision=precision) /* plus the function args*/ ]
    # TODO: Make ElementDivide a proper operation
    // ElementDivide is composed from existing operations: _ times the reciprocal of y
    ElementDivide(_, y, tag='', precision=precision) = ElementTimes(_, Reciprocal(y, precision=precision), tag=tag, precision=precision)
    ElementTimes(_, y, tag='', precision=precision) = new ComputationNode [ operation = 'ElementTimes' ; inputs = _AsNodes (_ : y, precision=precision) /*plus the function args*/ ]
    Exp(_, tag='', precision=precision) = new ComputationNode [ operation = 'Exp' ; inputs = _AsNodes (_, precision=precision) /*plus the function args*/ ]
    Floor(_, tag='', precision=precision) = new ComputationNode [ operation = 'Floor' ; inputs = _AsNodes (_, precision=precision) /*plus the function args*/ ]
    Log(_, tag='', precision=precision) = new ComputationNode [ operation = 'Log' ; inputs = _AsNodes (_, precision=precision) /*plus the function args*/ ]
    Minus(_, y, tag='', precision=precision) = new ComputationNode [ operation = 'Minus' ; inputs = _AsNodes (_ : y, precision=precision) /*plus the function args*/ ]
    Plus(_, y, tag='', precision=precision) = new ComputationNode [ operation = 'Plus' ; inputs = _AsNodes (_ : y, precision=precision) /*plus the function args*/ ]
    // Round is composed from existing operations: Round (x) = Floor (x + 0.5)
    Round(_, tag='', precision=precision) = Floor(Plus(_, ConstantTensor(0.5, (1), precision=precision), precision=precision), tag=tag, precision=precision)
    Sqrt(_, tag='', precision=precision) = new ComputationNode [ operation = 'Sqrt' ; inputs = _AsNodes (_, precision=precision) /*plus the function args*/ ]
    // Square is composed from existing operations: the input times itself
    Square(_, tag='', precision=precision) = ElementTimes(_, _, tag=tag, precision=precision)
    Tanh(_, tag='', precision=precision) = new ComputationNode [ operation = 'Tanh' ; inputs = _AsNodes (_, precision=precision) /*plus the function args*/ ]
    StraightThrough(_, tag='', precision=precision) = new ComputationNode [ operation = 'StraightThrough' ; inputs = _AsNodes (_, precision=precision) /*plus the function args*/ ]

    // 6. Reductions
    // All reductions create a node with operation 'ReduceElements'; they differ only in 'reductionOp'.
    // If 'axis' is not given (None), the node receives axis = 0.
    // NOTE(review): presumably axis 0 means 'reduce over all axes' -- confirm against the ReduceElements node.
    ReduceSum   (_, axis=None, tag='', precision=precision) = { axis1 = if BS.Constants.IsNone (axis) then 0 else axis ; r = new ComputationNode [ operation = 'ReduceElements' ; inputs = _AsNodes (_, precision=precision) ; axis = axis1 ; reductionOp = "Sum"    /*plus the function args*/ ]}.r
    ReduceLogSum(_, axis=None, tag='', precision=precision) = { axis1 = if BS.Constants.IsNone (axis) then 0 else axis ; r = new ComputationNode [ operation = 'ReduceElements' ; inputs = _AsNodes (_, precision=precision) ; axis = axis1 ; reductionOp = "LogSum" /*plus the function args*/ ]}.r
    ReduceMean  (_, axis=None, tag='', precision=precision) = { axis1 = if BS.Constants.IsNone (axis) then 0 else axis ; r = new ComputationNode [ operation = 'ReduceElements' ; inputs = _AsNodes (_, precision=precision) ; axis = axis1 ; reductionOp = "Mean"   /*plus the function args*/ ]}.r
    ReduceMin   (_, axis=None, tag='', precision=precision) = { axis1 = if BS.Constants.IsNone (axis) then 0 else axis ; r = new ComputationNode [ operation = 'ReduceElements' ; inputs = _AsNodes (_, precision=precision) ; axis = axis1 ; reductionOp = "Min"    /*plus the function args*/ ]}.r
    ReduceMax   (_, axis=None, tag='', precision=precision) = { axis1 = if BS.Constants.IsNone (axis) then 0 else axis ; r = new ComputationNode [ operation = 'ReduceElements' ; inputs = _AsNodes (_, precision=precision) ; axis = axis1 ; reductionOp = "Max"    /*plus the function args*/ ]}.r

    // 7. Control flow (if, composite etc.)
    // None so far

    // 8. Boolean operations
    // None so far

    // 9. Recurrent operations
    // Changes: input first; input -> _
    // FutureValue/PastValue: create a node with operation 'FutureValue'/'PastValue' on input '_'.
    // Defaults: timeStep = 1, defaultHiddenActivation = 0.1.
    FutureValue(_, shape, timeStep = 1, defaultHiddenActivation = 0.1, tag='', precision=precision) = new ComputationNode [ operation = 'FutureValue' ; inputs = _AsNodes (_, precision=precision) ; shape = new TensorShape [ /*shape*/ ] /*plus the function args*/ ]
    PastValue(_, shape, timeStep = 1, defaultHiddenActivation = 0.1, tag='', precision=precision) = new ComputationNode [ operation = 'PastValue' ; inputs = _AsNodes (_, precision=precision) ; shape = new TensorShape [ /*shape*/ ] /*plus the function args*/ ]

    // 10. NN-specific operations
    // Changes: input -> _, RectifiedLinear -> ReLU
    // Each function creates a single node on input '_'. Note that ReLU maps to the operation named 'RectifiedLinear'.
    ReLU(_, tag='', precision=precision) = new ComputationNode [ operation = 'RectifiedLinear' ; inputs = _AsNodes (_, precision=precision) /*plus the function args*/ ]
    Relu = ReLU // [Use Relu to arrive at relu() in snake_case]
    Sigmoid(_, tag='', precision=precision) = new ComputationNode [ operation = 'Sigmoid' ; inputs = _AsNodes (_, precision=precision) /*plus the function args*/ ]
    Softmax(_, tag='', precision=precision) = new ComputationNode [ operation = 'Softmax' ; inputs = _AsNodes (_, precision=precision) /*plus the function args*/ ]
    Dropout(_, tag='', precision=precision) = new ComputationNode [ operation = 'Dropout' ; inputs = _AsNodes (_, precision=precision) /*plus the function args*/ ]

    // 11. Criterion nodes
    // No changes here - we said the default input would be the label sequence here, against which the
    // empirical sequence is compared to. Keeping this for now.
    // CrossEntropyWithSoftmax:
    //  - axis == 0: creates a node with operation 'CrossEntropyWithSoftmax' with inputs (labels, outputs)
    //  - otherwise: composed from reductions along 'axis' as
    //    ReduceLogSum (outputs) - ReduceSum (labels .* outputs)
    CrossEntropyWithSoftmax(labelSequence, outProbVectorSequence, axis=0, tag='', precision=precision) =
        if axis==0 then new ComputationNode [ operation = 'CrossEntropyWithSoftmax' ; inputs = _AsNodes (labelSequence : outProbVectorSequence, precision=precision) /*plus the function args*/ ]
        else [ tag1 = tag; out = Minus (ReduceLogSum (outProbVectorSequence, axis=axis, precision=precision), ReduceSum (ElementTimes(labelSequence, outProbVectorSequence, precision=precision), axis=axis, precision=precision), tag=tag1, precision=precision) ].out
    # Classification error along a specific axis: account only for missed labels, i.e.
    # strictly check whether at the one '1' location in labels we find a value equal to the max
    // ClassificationError -- error criterion over (labelSequence, outVectorSequence).
    //  - axis == 0: a single 'ClassificationError' node; topN is appended as a Constant input only when it is not 1
    //  - axis != 0 and topN != 1: fails, since topN is not supported along a specific axis
    //  - axis != 0 and topN == 1: composed from reductions and comparisons along 'axis'
    ClassificationError(labelSequence, outVectorSequence, topN=1, axis=0, tag='', precision=precision) =
        if axis==0
        then new ComputationNode {
            operation = 'ClassificationError'
            inputs = _AsNodes (if topN == 1 then (labelSequence : outVectorSequence) else  (labelSequence : outVectorSequence : Constant (topN, precision=precision)), precision=precision)
            /*plus the function args*/
        }
        else if topN != 1
        then Fail ("ClassificationError() along a specific axis does not support topN.")
        else {
            // maximum value along the competition axis
            axMax     = ReduceMax (outVectorSequence, axis=axis, precision=precision)
            // 1 wherever the value equals that maximum
            pred      = outVectorSequence == axMax
            // positions where label and prediction disagree
            wrongPred = labelSequence != pred
            // number of disagreements along the competition axis
            axErr     = ReduceSum (wrongPred, axis=axis, precision=precision)
            // count at most one error per prediction
            capErr    = axErr >= 1
            // average into a single number per sample; this node carries the caller's tag
            err       = ReduceMean (capErr, tag=tag, precision=precision)
        }.err
    // legacy name for ClassificationError
    ErrorPrediction = ClassificationError
    # TODO: replace with this (need to deal with topN thing):
    # (_new will be removed once the change is made)
    // CrossEntropyWithSoftmax_new -- ReduceLogSum (z) minus TransposeTimes (L, z); the result node carries 'tag'.
    // 'precision' is passed to the outer Minus as well, so that all three nodes use the caller's precision.
    CrossEntropyWithSoftmax_new (L, z, tag='', precision=precision) = Minus (ReduceLogSum (z, precision=precision), TransposeTimes (L,          z, precision=precision),  tag=tag, precision=precision)
    // ClassificationError_new -- BS.Constants.One minus TransposeTimes (L, Hardmax (z)); the result node carries 'tag'.
    ClassificationError_new (L, z, tag='', precision=precision)     = Minus (BS.Constants.One, TransposeTimes (L, Hardmax (z, precision=precision), precision=precision), tag=tag, precision=precision)

    // 12. Comparison nodes
    // Each comparison builds a two-input node (_ : y) whose operation matches the function name.
    Less(_, y, tag='', precision=precision) =
        new ComputationNode { operation = 'Less' ; inputs = _AsNodes (_ : y, precision=precision) /*plus the function args*/ }
    Equal(_, y, tag='', precision=precision) =
        new ComputationNode { operation = 'Equal' ; inputs = _AsNodes (_ : y, precision=precision) /*plus the function args*/ }
    Greater(_, y, tag='', precision=precision) =
        new ComputationNode { operation = 'Greater' ; inputs = _AsNodes (_ : y, precision=precision) /*plus the function args*/ }
    GreaterEqual(_, y, tag='', precision=precision) =
        new ComputationNode { operation = 'GreaterEqual' ; inputs = _AsNodes (_ : y, precision=precision) /*plus the function args*/ }
    NotEqual(_, y, tag='', precision=precision) =
        new ComputationNode { operation = 'NotEqual' ; inputs = _AsNodes (_ : y, precision=precision) /*plus the function args*/ }
    LessEqual(_, y, tag='', precision=precision) =
        new ComputationNode { operation = 'LessEqual' ; inputs = _AsNodes (_ : y, precision=precision) /*plus the function args*/ }

    // 13. Others
    // Pass -- node with operation 'Pass' over the input _.
    Pass(_, tag='', precision=precision) =
        new ComputationNode {
            operation = 'Pass'
            inputs = _AsNodes (_, precision=precision)
            /*plus the function args*/
        }
    // Identity -- alias of Pass
    Identity = Pass

    // GetRandomSample(weights /* vector of length nClasses */, numSamples, sampleWithReplacement) randomly draws numSamples samples using the specified sampling weights.
    // The result is a sparse matrix whose columns are numSamples one-hot vectors.
    // Builds a 'RandomSample' node: numSamples becomes sizeOfSampledSet, sampleWithReplacement becomes
    // allowDuplicates, and _ is used directly as the node's inputs.
    GetRandomSample(_ ,numSamples, sampleWithReplacement, tag='', precision=precision) =
        new ComputationNode {
            operation = 'RandomSample'
            sizeOfSampledSet = numSamples
            allowDuplicates = sampleWithReplacement
            inputs = _
            /*plus the function args*/
        }

    // GetInclusionFrequency(weights /* vector of length nClasses */, numSamples, sampleWithReplacement) has to be seen in conjunction with GetRandomSample(...).
    // While GetRandomSample(...) creates a set of samples, GetInclusionFrequency(...) tells how often each class is expected to occur in the sampled sets.
    // For sampling with replacement the relation to the sampling weights is trivial, but not for sampling without replacement.
    // Builds a 'RandomSampleInclusionFrequency' node: numSamples becomes sizeOfSampledSet,
    // sampleWithReplacement becomes allowDuplicates, and _ is used directly as the node's inputs.
    GetInclusionFrequency(_ ,numSamples, sampleWithReplacement, tag='', precision=precision) =
        new ComputationNode {
            operation = 'RandomSampleInclusionFrequency'
            sizeOfSampledSet = numSamples
            allowDuplicates = sampleWithReplacement
            inputs = _
            /*plus the function args*/
        }
]

# Parameter{} can do several forms of initialization.
#  - initValue=scalar, value=array --> initialize from this value  --array form not implemented yet
#  - initFromFilePath="..." --> read from a data file
#  - init="uniform|gaussian" (random init scaled by initValueScale).
#  - init="zero"
# deprecated:
#  - initFromLiteral="..." (deprecated) --> parse a string literal (obsolete with value=array form)
#  - init="fixedValue", value from 'value'
# Warning: The current config will behave unexpectedly if the user mistypes 'initValue' as 'value' (which will be ignored, defaulting to "uniform" init)
# Parameter -- 'LearnableParameter' node of shape (outputDim : inputDim), with initFilterRank fixed to 0
# and initOutputRank fixed to 1. All other settings come from the function arguments.
Parameter {outputDim, inputDim, learningRateMultiplier = 1.0, init = ''/*|uniform|fixedValue|gaussian|fromFile|fromLiteral*/, initValueScale = 1, value = 0/*deprecated*/, initValue = '', initFromFilePath = '', initFromLiteral = ''/*deprecated*/, initOnCPUOnly=true, randomSeed=-1, tag='', precision=precision} =
    new ComputationNode {
        operation = 'LearnableParameter'
        initFilterRank = 0
        initOutputRank = 1
        shape = new TensorShape { dims = (outputDim : inputDim) }
        /*plus the function args*/
    }

# deprecated alias of Parameter
LearnableParameter = Parameter

# TODO: make Parameter take tensor dims?
# ParameterTensor -- 'LearnableParameter' node; unlike Parameter, initFilterRank and initOutputRank are arguments.
ParameterTensor {dims, learningRateMultiplier = 1.0, init = ''/*|uniform|fixedValue|gaussian|fromFile|fromLiteral*/, initValueScale = 1, value = 0, initValue = '', initFilterRank = 0, initOutputRank = 1, initFromFilePath = '', initFromLiteral = '', initOnCPUOnly=true, randomSeed=-1, tag='', precision=precision} =
    new ComputationNode {
        operation = 'LearnableParameter'
        shape = new TensorShape [ /*dims*/ ]
        /*plus the function args*/
    }
# ConstantFromString -- ParameterTensor initialized from a string literal, with learningRateMultiplier 0.0.
# NOTE(review): 'tag' is accepted here but is not passed on to ParameterTensor.
ConstantFromString(literal, tag='', precision=precision) =
    ParameterTensor((0)/*dim, will be inferred*/, initFromLiteral = literal, learningRateMultiplier = 0.0, precision=precision)
# TODO: Deprecate ConstantFromString() in favor of Constant(array expression)
# DynamicAxis -- node with operation 'DynamicAxis'; it has no explicit inputs.
DynamicAxis(tag='', precision=precision) =
    new ComputationNode {
        operation = 'DynamicAxis'
        /*plus the function args*/
    }
# Input -- input node; the default tag is 'feature'.
#  - sparse == true: delegates to SparseInput with the same dims, dynamicAxis, tag and precision
#  - otherwise: an 'InputValue' node with isImage = false
Input(dims, dynamicAxis='', sparse=false, tag='feature', precision=precision) =
    if sparse
    then SparseInput(dims, dynamicAxis=dynamicAxis, tag=tag, precision=precision)
    else new ComputationNode {
        operation = 'InputValue'
        shape = new TensorShape [ /*dims*/ ]
        isImage = false
        /*plus the function args*/
    }
# TODO: change from dynamicAxis by name to dynamicAxis being an actual object
# the following variants of Input() are deprecated
# SparseInput -- 'SparseInputValue' node with isImage = false.
SparseInput(dims, dynamicAxis='', tag='feature', precision=precision) =
    new ComputationNode {
        operation = 'SparseInputValue'
        shape = new TensorShape [ /*dims*/ ]
        isImage = false
        /*plus the function args*/
    }
# ImageInput -- 'InputValue' node with isImage = true; the image dimensions come from the function arguments.
ImageInput(imageWidth, imageHeight, imageChannels, imageLayout='CHW', dynamicAxis='', tag='feature', precision=precision) =
    new ComputationNode {
        operation = 'InputValue'
        isImage = true
        /*plus the function args*/
    }
# SparseImageInput -- 'SparseInputValue' node with isImage = true.
SparseImageInput(imageWidth, imageHeight, imageChannels, imageLayout='CHW', dynamicAxis='', tag='feature', precision=precision) =
    new ComputationNode {
        operation = 'SparseInputValue'
        isImage = true
        /*plus the function args*/
    }
# EnvironmentInput -- 'EnvironmentInput' node; propertyName is supplied through the function arguments.
EnvironmentInput(propertyName, tag='', precision=precision) =
    new ComputationNode {
        operation = 'EnvironmentInput'
        /*plus the function args*/
    }
# TODO: make 'dims' the first parameter, think ConstantTensor<dims> (val)
# ConstantTensor -- ParameterTensor of the given dims with initValue = val and learningRateMultiplier = 0.
# NOTE(review): 'tag' is accepted but is not passed on to ParameterTensor.
ConstantTensor(val, dims, tag='', precision=precision) =
    ParameterTensor(dims, learningRateMultiplier = 0, initValue = val, precision=precision)
# Constant -- Parameter of size rows x cols (default 1 x 1) with initValue = val and learningRateMultiplier = 0.
# NOTE(review): 'tag' is accepted but is not passed on to Parameter.
Constant(val, rows = 1, cols = 1, tag='', precision=precision) =
    Parameter(rows, cols, learningRateMultiplier = 0, initValue = val, precision=precision)
# in PastValue/FutureValue, initialState, if given, overrides defaultHiddenActivation
# for back compat, we must keep this value of 0.1 for defaultHiddenActivation even if it is strange
# _PFValue -- shared implementation of PastValue and FutureValue.
# The operation name is pastOrFuture + 'Value'. initialState is added as a second input only when it is not None.
_PFValue   (pastOrFuture, dims, input, timeStep = 1, initialState = None, defaultHiddenActivation = 0.1, tag='', precision=precision) =
    new ComputationNode {
        operation = pastOrFuture + 'Value'
        inputs = _AsNodes (if BS.Constants.IsNone (initialState) then input else (input:initialState), precision=precision)
        shape = new TensorShape [ /*dims*/ ]
        /*plus the function args*/
    }
# PastValue -- _PFValue with pastOrFuture = "Past"; all optional arguments are forwarded.
PastValue   (dims, input, timeStep = 1, initialState = None, defaultHiddenActivation = 0.1, tag='', precision=precision) =
    _PFValue ("Past",   dims, input, timeStep=timeStep, initialState=initialState, defaultHiddenActivation=defaultHiddenActivation, tag=tag, precision=precision)
# FutureValue -- _PFValue with pastOrFuture = "Future"; all optional arguments are forwarded.
FutureValue (dims, input, timeStep = 1, initialState = None, defaultHiddenActivation = 0.1, tag='', precision=precision) =
    _PFValue ("Future", dims, input, timeStep=timeStep, initialState=initialState, defaultHiddenActivation=defaultHiddenActivation, tag=tag, precision=precision)
# Shift -- 'Shift' node over (input : boundaryValue); fromOffset, boundaryMode and dim come from the function arguments.
Shift(input, fromOffset, boundaryValue, boundaryMode=-1/*context*/, dim=-1, tag='', precision=precision) =
    new ComputationNode {
        operation = 'Shift'
        inputs = _AsNodes (input : boundaryValue, precision=precision)
        /*plus the function args*/
    }
# RowSlice -- Slice along axis 1 covering [beginIndex, beginIndex + numRows).
# 'tag' is forwarded to Slice so that a tag given by the caller is applied to the resulting node.
RowSlice(beginIndex, numRows, input, tag='', precision=precision) = Slice(beginIndex, beginIndex + numRows, input, axis = 1, tag=tag, precision=precision)
# RowRepeat -- 'RowRepeat' node over input; numRepeats comes from the function arguments.
RowRepeat(input, numRepeats, tag='', precision=precision) =
    new ComputationNode {
        operation = 'RowRepeat'
        inputs = _AsNodes (input, precision=precision)
        /*plus the function args*/
    }
# RowStack -- 'RowStack' node; 'inputs' and 'axis' are taken from the function arguments.
RowStack(inputs, axis=1, tag='', precision=precision) =
    new ComputationNode {
        operation = 'RowStack'
        /*plus the function args*/
    }
# EditDistanceError -- 'EditDistanceError' node over (leftInput : rightInput); penalties and options come from the arguments.
EditDistanceError(leftInput, rightInput, subPen=1.0, delPen=1.0, insPen=1.0, squashInputs=false, tokensToIgnore=[||], tag='', precision=precision) =
    new ComputationNode { operation = 'EditDistanceError' ; inputs = _AsNodes (leftInput : rightInput, precision=precision) /*plus the function args*/ }
# LatticeSequenceWithSoftmax -- node over (labels : evaluation : scaledLogLikelihood : lattice); paths and settings come from the arguments.
LatticeSequenceWithSoftmax(labels, evaluation, scaledLogLikelihood, lattice, symListPath, phonePath, stateListPath, transProbPath, latticeConfigPath = "LatticeNode.config", hSmoothingWeight = 0.95, frameDropThresh = 1e-10, doReferenceAlign = false, seqGammarUsesMBR = false, seqGammarAMF = 14.0, seqGammarLMF = 14.0, seqGammarBMMIFactor = 0.0, seqGammarWordPen = 0.0, tag='', precision=precision) =
    new ComputationNode { operation = 'LatticeSequenceWithSoftmax' ; inputs = _AsNodes (labels : evaluation : scaledLogLikelihood : lattice, precision=precision) /*plus the function args*/ }
# ForwardBackward -- 'ForwardBackward' node over (graph : features).
ForwardBackward(graph, features, blankTokenId, delayConstraint=-1, tag='', precision=precision) =
    new ComputationNode { operation = 'ForwardBackward' ; inputs = _AsNodes (graph : features, precision=precision) /*plus the function args*/ }
# LabelsToGraph -- 'LabelsToGraph' node over labels.
LabelsToGraph(labels, tag='', precision=precision) =
    new ComputationNode { operation = 'LabelsToGraph' ; inputs = _AsNodes (labels, precision=precision) /*plus the function args*/ }
# StopGradient -- 'StopGradient' node over input.
StopGradient(input, tag='', precision=precision) =
    new ComputationNode { operation = 'StopGradient' ; inputs = _AsNodes (input, precision=precision) /*plus the function args*/ }
# Slice -- select the range [beginIndex, endIndex) of input along 'axis'.
#  - axis < 0 (time axis; specify -1): a per-step flag sequence is built with BS.Loop.IsFirstN / IsLastN
#    and applied with BS.Sequences.Gather. Positive indices go to IsFirstN, non-positive ones are negated
#    and go to IsLastN. If both indices are 0, input is returned unchanged.
#  - otherwise (non-time axis): a single 'Slice' node.
Slice(beginIndex, endIndex, input, axis=1, tag='', precision=precision) =
    if axis < 0
    then {
        # steps at or after the begin position
        beginFlags = if beginIndex > 0 then BS.Boolean.Not (BS.Loop.IsFirstN (beginIndex, input)) else BS.Loop.IsLastN (-beginIndex, input)
        # steps before the end position
        endFlags   = if endIndex > 0 then BS.Loop.IsFirstN (endIndex, input) else BS.Boolean.Not (BS.Loop.IsLastN (-endIndex, input))
        # an index of 0 leaves that side unrestricted, so only the other side's flags are used
        flags = if      beginIndex == 0 then endFlags
                else if endIndex   == 0 then beginFlags
                else                         BS.Boolean.And (beginFlags, endFlags)
        out = if beginIndex == 0 && endIndex == 0
              then input
              else BS.Sequences.Gather (flags, input)
    }.out
    else new ComputationNode {
        operation = 'Slice'
        inputs = _AsNodes (input, precision=precision)
        /*plus the function args*/
    }
# Reshape -- node with operation 'LegacyReshape' over input.
Reshape(input, numRows, imageWidth = 0, imageHeight = 0, imageChannels = 0, tag='', precision=precision) =
    new ComputationNode {
        operation = 'LegacyReshape'
        inputs = _AsNodes (input, precision=precision)
        /*plus the function args*/
    }
# NewReshape -- node with operation 'Reshape' over input; beginAxis and endAxis come from the function arguments.
NewReshape(input, dims, beginAxis=0, endAxis=0, tag='', precision=precision) =
    new ComputationNode {
        operation = 'Reshape'
        inputs = _AsNodes (input, precision=precision)
        shape = new TensorShape [ /*dims*/ ]
        /*plus the function args*/
    }
# ReshapeDimension -- NewReshape applied to the axis range [axis, axis + 1) with tensorShape as dims.
ReshapeDimension(x, axis, tensorShape, precision=precision) =
    NewReshape(x, tensorShape, beginAxis=axis, endAxis=axis + 1, precision=precision)
# FlattenDimensions -- NewReshape applied to the axis range [axis, axis + num) with dims = 0.
FlattenDimensions(x, axis, num, precision=precision) =
    NewReshape(x, 0, beginAxis=axis, endAxis=axis + num, precision=precision)
# SplitDimension -- ReshapeDimension of 'axis' with tensorShape 0:N.
SplitDimension(x, axis, N, precision=precision) =
    ReshapeDimension(x, axis, 0:N, precision=precision)
# TODO: make input the last arg!
# Transpose -- TransposeDimensions of x with the axis arguments 1 and 2.
Transpose(x, precision=precision) = TransposeDimensions(x, 1, 2, precision=precision)
# LambdaRank -- 'LambdaRank' node over (gain : prediction : queryId).
LambdaRank(gain, prediction, queryId, tag='', precision=precision) =
    new ComputationNode { operation = 'LambdaRank' ; inputs = _AsNodes (gain : prediction : queryId, precision=precision) /*plus the function args*/ }
# NDCG1Eval -- 'NDCG1Eval' node over (gain : prediction : queryId).
NDCG1Eval(gain, prediction, queryId, tag='', precision=precision) =
    new ComputationNode { operation = 'NDCG1Eval' ; inputs = _AsNodes (gain : prediction : queryId, precision=precision) /*plus the function args*/ }
# Logistic -- 'Logistic' node over (label : probability).
Logistic(label, probability, tag='', precision=precision) =
    new ComputationNode { operation = 'Logistic' ; inputs = _AsNodes (label : probability, precision=precision) /*plus the function args*/ }
# WeightedLogistic -- the same 'Logistic' operation with instanceWeight as an additional third input.
WeightedLogistic(label, probability, instanceWeight, tag='', precision=precision) =
    new ComputationNode { operation = 'Logistic' ; inputs = _AsNodes (label : probability : instanceWeight, precision=precision) /*plus the function args*/ }
# ReconcileDynamicAxis -- 'ReconcileDynamicAxis' node over (dataInput : layoutInput).
ReconcileDynamicAxis(dataInput, layoutInput, tag='', precision=precision) =
    new ComputationNode {
        operation = 'ReconcileDynamicAxis'
        inputs = _AsNodes (dataInput : layoutInput, precision=precision)
        /*plus the function args*/
    }
# back-compat alias
ReconcileMBLayout = ReconcileDynamicAxis
# CastAs -- read as CastAs<type>(data), where the cast may consist of rearranging the data w.r.t. MBLayout
# or broadcasting across sequence items. Note the swapped argument order in the call below.
CastAs (type, data, precision=precision) =
    ReconcileDynamicAxis (data, type, precision=precision)
# ND convo & pooling/unpooling   --why is autoPadding true? Normally one would want to reduce dimensions, no?
# Convolution -- 'Convolution' node over (weightNode : inputValueNode) with transpose = false.
# The scalar/array arguments are wrapped into TensorShape / BoolVector records; dimOutputShape is fixed to 0.
Convolution(weightNode, inputValueNode, kernelDims, mapDims = 0, stride = 1, sharing = true, autoPadding = true, lowerPad = 0, upperPad = 0, dilation = 1, imageLayout='CHW', maxTempMemSizeInSamples = 0, tag='', precision=precision) =
    new ComputationNode {
        operation = 'Convolution'
        inputs = _AsNodes (weightNode : inputValueNode, precision=precision)
        kernelShape = new TensorShape { dims = kernelDims }
        mapCount = new TensorShape { dims = mapDims }
        strideShape = new TensorShape { dims = stride }
        dimSharing = new BoolVector { items = sharing }
        dimPadding = new BoolVector { items = autoPadding }
        dimPadLower = new TensorShape { dims = lowerPad }
        dimPadUpper = new TensorShape { dims = upperPad }
        dimDilation = new TensorShape { dims = dilation }
        transpose = false
        dimOutputShape = new TensorShape { dims = 0 }
        /*plus the function args*/
    }
# ConvolutionTranspose -- the same 'Convolution' operation with transpose = true.
# dimOutputShape is outputShape when given, and 0 when outputShape is None.
ConvolutionTranspose(weightNode, inputValueNode, kernelDims, mapDims = 0, stride = 1, sharing = true, autoPadding = true, lowerPad = 0, upperPad = 0, outputShape = None, dilation = 1, imageLayout='CHW', maxTempMemSizeInSamples = 0, tag='', precision=precision) =
    new ComputationNode {
        operation = 'Convolution'
        inputs = _AsNodes (weightNode : inputValueNode, precision=precision)
        kernelShape = new TensorShape { dims = kernelDims }
        mapCount = new TensorShape { dims = mapDims }
        strideShape = new TensorShape { dims = stride }
        dimSharing = new BoolVector { items = sharing }
        dimPadding = new BoolVector { items = autoPadding }
        dimPadLower = new TensorShape { dims = lowerPad }
        dimPadUpper = new TensorShape { dims = upperPad }
        dimDilation = new TensorShape { dims = dilation }
        transpose = true
        dimOutputShape = new TensorShape { dims = if BS.Constants.IsNone (outputShape) then 0 else outputShape }
        /*plus the function args*/
    }
# Pooling -- 'Pooling' node over input; poolKind ('max' or 'average') is passed on as 'pool'.
Pooling(input, poolKind/*'max'|'average'*/, kernelDims, stride=1, autoPadding = true, lowerPad = 0, upperPad = 0, ceilOutDim = false, includePad = false, imageLayout='CHW', tag='', precision=precision) =
    new ComputationNode {
        operation = 'Pooling'
        inputs = _AsNodes (input, precision=precision)
        pool = poolKind
        kernelShape = new TensorShape { dims = kernelDims }
        strideShape = new TensorShape { dims = stride }
        dimPadding = new BoolVector { items = autoPadding }
        dimPadLower = new TensorShape { dims = lowerPad }
        dimPadUpper = new TensorShape { dims = upperPad }
        ceilOut = ceilOutDim
        poolIncludePad = includePad
        /*plus the function args*/
    }
# MaxUnpooling -- 'MaxUnpooling' node over (unpoolInput : poolInput).
MaxUnpooling(unpoolInput, poolInput, kernelDims, stride=1, autoPadding = true, lowerPad = 0, upperPad = 0, imageLayout='CHW', tag='', precision=precision) =
    new ComputationNode {
        operation = 'MaxUnpooling'
        inputs = _AsNodes (unpoolInput : poolInput, precision=precision)
        kernelShape = new TensorShape { dims = kernelDims }
        strideShape = new TensorShape { dims = stride }
        dimPadding = new BoolVector { items = autoPadding }
        dimPadLower = new TensorShape { dims = lowerPad }
        dimPadUpper = new TensorShape { dims = upperPad }
        /*plus the function args*/
    }
# 2D pooling
# MaxPooling -- 'MaxPooling' node over input; window and subsampling sizes come from the function arguments.
MaxPooling(input, windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayout='CHW', tag='', precision=precision) =
    new ComputationNode {
        operation = 'MaxPooling'
        inputs = _AsNodes (input, precision=precision)
        /*plus the function args*/
    }
# AveragePooling -- 'AveragePooling' node over input, with the same arguments as MaxPooling.
AveragePooling(input, windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayout='CHW', tag='', precision=precision) =
    new ComputationNode {
        operation = 'AveragePooling'
        inputs = _AsNodes (input, precision=precision)
        /*plus the function args*/
    }
# ROIPooling -- 'ROIPooling' node over (input : ROIs) with pool = 'max'; 'shape' becomes roiOutputShape and
# spatialScale becomes featureScale. There is no tag argument; tag is fixed to ''.
ROIPooling (input, ROIs, shape, spatialScale = 0.0625, precision=precision) =
    new ComputationNode {
        operation = 'ROIPooling'
        inputs = _AsNodes (input : ROIs, precision=precision)
        pool = 'max'
        roiOutputShape = new TensorShape { dims = shape }
        featureScale = spatialScale
        tag=''
        /*plus the function args*/
    }
# Aliases: alternative names bound to functions defined elsewhere in this file.
ColumnwiseCrossProduct = KhatriRaoProduct // deprecated
ErrorPrediction = ClassificationError   # legacy name
# Delay is another name for PastValue
Delay = PastValue

# Single-input nodes over x; each operation name matches the function name.
Acos(x, tag='', precision=precision) =
    new ComputationNode { operation = 'Acos' ; inputs = _AsNodes (x, precision=precision) /*plus the function args*/ }
Asin(x, tag='', precision=precision) =
    new ComputationNode { operation = 'Asin' ; inputs = _AsNodes (x, precision=precision) /*plus the function args*/ }
Asinh(x, tag='', precision=precision) =
    new ComputationNode { operation = 'Asinh' ; inputs = _AsNodes (x, precision=precision) /*plus the function args*/ }
Atanh(x, tag='', precision=precision) =
    new ComputationNode { operation = 'Atanh' ; inputs = _AsNodes (x, precision=precision) /*plus the function args*/ }
# BatchNormalization -- 'BatchNormalization' node with six inputs:
# input, scale, bias, runMean, runVariance and a run count.
# When runCount is None, a ParameterTensor of dims (1) with initValue 0 and learningRateMultiplier 0
# is created in its place; otherwise the given runCount is used.
BatchNormalization(input, scale, bias, runMean, runVariance, runCount=None, spatial=false, normalizationTimeConstant = 0, blendTimeConstant = 0, epsilon = 0.00001, useCntkEngine = true, disableRegularization = false, imageLayout='CHW', tag='', precision=precision) = new ComputationNode
{
    operation = 'BatchNormalization'
    inputs =
        if BS.Constants.IsNone(runCount)
        # no run count supplied: substitute a zero-initialized parameter that is not learned
        then _AsNodes (input : scale : bias : runMean : runVariance : ParameterTensor {(1), initValue = 0, learningRateMultiplier = 0, precision=precision}, precision=precision)
        else _AsNodes (input : scale : bias : runMean : runVariance : runCount, precision=precision)
    /*plus the function args*/
}
# ClassBasedCrossEntropyWithSoftmax -- node over the four inputs, in the order of the arguments.
ClassBasedCrossEntropyWithSoftmax(labelClassDescriptorVectorSequence, mainInputInfo, mainWeight, classLogProbsBeforeSoftmax, tag='', precision=precision) =
    new ComputationNode { operation = 'ClassBasedCrossEntropyWithSoftmax' ; inputs = _AsNodes (labelClassDescriptorVectorSequence : mainInputInfo : mainWeight : classLogProbsBeforeSoftmax, precision=precision) /*plus the function args*/ }
# Clip -- 'Clip' node over (minValue : maxValue : x); note that the data argument x comes last.
Clip(minValue, maxValue, x, tag='', precision=precision) =
    new ComputationNode { operation = 'Clip' ; inputs = _AsNodes (minValue : maxValue : x, precision=precision) /*plus the function args*/ }
# ColumnElementTimes -- 'ColumnElementTimes' node over the two inputs.
ColumnElementTimes(aVectorSequence, anotherVectorSequence, tag='', precision=precision) =
    new ComputationNode { operation = 'ColumnElementTimes' ; inputs = _AsNodes (aVectorSequence : anotherVectorSequence, precision=precision) /*plus the function args*/ }
// TODO: ColumnElementTimes = ElementTimes
# CosDistance -- 'CosDistance' node over the two inputs.
CosDistance(aVectorSequence, anotherVectorSequence, tag='', precision=precision) =
    new ComputationNode { operation = 'CosDistance' ; inputs = _AsNodes (aVectorSequence : anotherVectorSequence, precision=precision) /*plus the function args*/ }
# CosDistanceWithNegativeSamples -- node over four inputs; numShifts and numNegSamples are passed as inputs too.
CosDistanceWithNegativeSamples(aVectorSequence, anotherVectorSequence, numShifts, numNegSamples, tag='', precision=precision) =
    new ComputationNode { operation = 'CosDistanceWithNegativeSamples' ; inputs = _AsNodes (aVectorSequence : anotherVectorSequence : numShifts : numNegSamples, precision=precision) /*plus the function args*/ }
# Cosh -- 'Cosh' node over x.
Cosh(x, tag='', precision=precision) =
    new ComputationNode { operation = 'Cosh' ; inputs = _AsNodes (x, precision=precision) /*plus the function args*/ }
# Cosine -- 'Cosine' node over x.
Cosine(x, tag='', precision=precision) =
    new ComputationNode { operation = 'Cosine' ; inputs = _AsNodes (x, precision=precision) /*plus the function args*/ }
# CrossEntropy -- 'CrossEntropy' node over (refProbVectorSequence : outProbVectorSequence).
CrossEntropy(refProbVectorSequence, outProbVectorSequence, tag='', precision=precision) =
    new ComputationNode { operation = 'CrossEntropy' ; inputs = _AsNodes (refProbVectorSequence : outProbVectorSequence, precision=precision) /*plus the function args*/ }
# DiagTimes -- 'DiagTimes' node over (diagonalMatrixAsColumnVector : matrix).
DiagTimes(diagonalMatrixAsColumnVector, matrix, tag='', precision=precision) =
    new ComputationNode { operation = 'DiagTimes' ; inputs = _AsNodes (diagonalMatrixAsColumnVector : matrix, precision=precision) /*plus the function args*/ }
// TODO: DiagTimes = ElementTimes
# GatherPacked -- 'GatherPacked' node over (indexSequence : sourceData).
GatherPacked(indexSequence, sourceData, tag='', precision=precision) =
    new ComputationNode { operation = 'GatherPacked' ; inputs = _AsNodes (indexSequence : sourceData, precision=precision) /*plus the function args*/ }
# GMMLogLikelihood -- node over the four inputs, in the order of the arguments.
GMMLogLikelihood(unnormalizedPriorVector, meansAsRows, logStdDevAsRows, dataVectorSequence, tag='', precision=precision) =
    new ComputationNode { operation = 'GMMLogLikelihood' ; inputs = _AsNodes (unnormalizedPriorVector : meansAsRows : logStdDevAsRows : dataVectorSequence, precision=precision) /*plus the function args*/ }
# InvStdDev -- 'InvStdDev' node over dataVectorSequence.
InvStdDev(dataVectorSequence, tag='', precision=precision) =
    new ComputationNode { operation = 'InvStdDev' ; inputs = _AsNodes (dataVectorSequence, precision=precision) /*plus the function args*/ }
# KhatriRaoProduct -- 'KhatriRaoProduct' node over (leftMatrix : rightMatrix).
KhatriRaoProduct(leftMatrix, rightMatrix, tag='', precision=precision) =
    new ComputationNode { operation = 'KhatriRaoProduct' ; inputs = _AsNodes (leftMatrix : rightMatrix, precision=precision) /*plus the function args*/ }
# LogPlus -- 'LogPlus' node over (leftMatrix : rightMatrix).
LogPlus(leftMatrix, rightMatrix, tag='', precision=precision) =
    new ComputationNode { operation = 'LogPlus' ; inputs = _AsNodes (leftMatrix : rightMatrix, precision=precision) /*plus the function args*/ }
# LogSoftmax -- 'LogSoftmax' node over z.
LogSoftmax(z, tag='', precision=precision) =
    new ComputationNode { operation = 'LogSoftmax' ; inputs = _AsNodes (z, precision=precision) /*plus the function args*/ }
# TODO: ^^ along axis, like Softmax
# MatrixL1Reg -- 'MatrixL1Reg' node over matrix.
MatrixL1Reg(matrix, tag='', precision=precision) =
    new ComputationNode { operation = 'MatrixL1Reg' ; inputs = _AsNodes (matrix, precision=precision) /*plus the function args*/ }
# MatrixL2Reg -- 'MatrixL2Reg' node over matrix.
MatrixL2Reg(matrix, tag='', precision=precision) =
    new ComputationNode { operation = 'MatrixL2Reg' ; inputs = _AsNodes (matrix, precision=precision) /*plus the function args*/ }
# Mean -- 'Mean' node over dataVectorSequence.
Mean(dataVectorSequence, tag='', precision=precision) =
    new ComputationNode { operation = 'Mean' ; inputs = _AsNodes (dataVectorSequence, precision=precision) /*plus the function args*/ }
# Negate -- 'Negate' node over input.
Negate(input, tag='', precision=precision) =
    new ComputationNode { operation = 'Negate' ; inputs = _AsNodes (input, precision=precision) /*plus the function args*/ }
# PackedIndex -- 'PackedIndex' node over (nodeWithLayoutToPackFor : indexSequence).
PackedIndex(nodeWithLayoutToPackFor, indexSequence, tag='', precision=precision) =
    new ComputationNode { operation = 'PackedIndex' ; inputs = _AsNodes (nodeWithLayoutToPackFor : indexSequence, precision=precision) /*plus the function args*/ }
# PerDimMeanVarDeNormalization -- node over (dataVectorSequence : meanVector : invStdDevVector).
PerDimMeanVarDeNormalization(dataVectorSequence, meanVector, invStdDevVector, tag='', precision=precision) =
    new ComputationNode { operation = 'PerDimMeanVarDeNormalization' ; inputs = _AsNodes (dataVectorSequence : meanVector : invStdDevVector, precision=precision) /*plus the function args*/ }
#PerDimMeanVarNormalization(dataVectorSequence, meanVector, invStdDevVector, tag='', precision=precision) = new ComputationNode [ operation = 'PerDimMeanVarNormalization' ; inputs = _AsNodes (dataVectorSequence : meanVector : invStdDevVector, precision=precision) /*plus the function args*/ ]
# PerDimMeanVarNormalization -- composed from existing operations: (x minus mean) times invStdDev. Takes no tag argument.
PerDimMeanVarNormalization (x, mean, invStdDev, precision=precision) =
    ElementTimes(Minus(x, mean, precision=precision), invStdDev, precision=precision)
# Reciprocal -- 'Reciprocal' node over z.
Reciprocal(z, tag='', precision=precision) =
    new ComputationNode { operation = 'Reciprocal' ; inputs = _AsNodes (z, precision=precision) /*plus the function args*/ }
//# the following is a temporary workaround until we have the C++ version
# TODO: change hiddenDims to hiddenShape and pass as a TensorShape (currently, the node only supports rank-1 data)
# OptimizedRNNStack -- node with operation 'OptimizedRNNStack'; inputs are (weights : input).
# The remaining settings (hiddenDims, numLayers, bidirectional, recurrentOp, axis) are function args.
OptimizedRNNStack(weights, input, hiddenDims, numLayers=1, bidirectional=false, recurrentOp='lstm', axis=-1, tag='', precision=precision) =
    new ComputationNode [
        operation = 'OptimizedRNNStack'
        inputs = _AsNodes (weights : input, precision=precision) /*plus the function args*/
    ]
# legacy: RNNStack (x, W, ...) is the older spelling of OptimizedRNNStack (W, x, ...).
# Note the swapped order of the two positional inputs and the renamed options
# (hiddenSize -> hiddenDims, rnnMode -> recurrentOp).
# All options given by the caller (numLayers, bidirectional, tag) are forwarded;
# they were previously overridden by hard-coded defaults.
RNNStack(x, W, hiddenSize=10, numLayers=1, bidirectional=false, rnnMode='lstm', tag='', precision=precision) =
    OptimizedRNNStack(W, x, hiddenSize, numLayers=numLayers, bidirectional=bidirectional, recurrentOp=rnnMode, tag=tag, precision=precision)
# Scale -- node with operation 'Scale'; inputs are (scalarScalingFactor : matrix)
# TODO: Scale = ElementTimes
Scale(scalarScalingFactor, matrix, tag='', precision=precision) =
    new ComputationNode [
        operation = 'Scale'
        inputs = _AsNodes (scalarScalingFactor : matrix, precision=precision) /*plus the function args*/
    ]
# ScatterPacked -- node with operation 'ScatterPacked'; inputs are (cond : indexSequence : sourceData)
ScatterPacked(cond, indexSequence, sourceData, tag='', precision=precision) =
    new ComputationNode [
        operation = 'ScatterPacked'
        inputs = _AsNodes (cond : indexSequence : sourceData, precision=precision) /*plus the function args*/
    ]
# Sin -- node with operation 'Sin' over z
Sin(z, tag='', precision=precision) =
    new ComputationNode [
        operation = 'Sin'
        inputs = _AsNodes (z, precision=precision) /*plus the function args*/
    ]
# Sinh -- node with operation 'Sinh' over x
Sinh(x, tag='', precision=precision) =
    new ComputationNode [
        operation = 'Sinh'
        inputs = _AsNodes (x, precision=precision) /*plus the function args*/
    ]
# Softmax -- softmax of z
#  - axis == 0 (default): a single node with operation 'Softmax'
#  - otherwise: composed as Exp (z - ReduceLogSum (z, axis)), i.e. normalized along the given axis
# The 'tag' is applied to the resulting node in both cases (for axis != 0 it is
# forwarded to the final Exp node; it used to be dropped on that path).
Softmax (z, axis=0, tag='', precision=precision) =
    if axis == 0 then new ComputationNode [ operation = 'Softmax' ; inputs = _AsNodes (z, precision=precision) /*plus the function args*/ ]
    else
    [
        Z = ReduceLogSum (z, axis=axis, precision=precision) # log of the normalizer, reduced along axis
        P = Exp (z - Z, tag=tag, precision=precision)         # the result node; carries the caller's tag
    ].P
# Hardmax -- node with operation 'Hardmax' over z
Hardmax(z, tag='', precision=precision) =
    new ComputationNode [
        operation = 'Hardmax'
        inputs = _AsNodes (z, precision=precision) /*plus the function args*/
    ]
# Sqrt -- node with operation 'Sqrt' over z
Sqrt(z, tag='', precision=precision) =
    new ComputationNode [
        operation = 'Sqrt'
        inputs = _AsNodes (z, precision=precision) /*plus the function args*/
    ]
# SquareError -- node with operation 'SquareError'; inputs are (aMatrix : anotherMatrix)
SquareError(aMatrix, anotherMatrix, tag='', precision=precision) =
    new ComputationNode [
        operation = 'SquareError'
        inputs = _AsNodes (aMatrix : anotherMatrix, precision=precision) /*plus the function args*/
    ]
# SumColumnElements -- node with operation 'SumColumnElements' over z (deprecated)
SumColumnElements(z, tag='', precision=precision) =
    new ComputationNode [
        operation = 'SumColumnElements'
        inputs = _AsNodes (z, precision=precision) /*plus the function args*/
    ]
# SumElements -- node with operation 'SumElements' over a matrix
# TODO: Rename to ReduceSumMB?
SumElements(matrix, tag='', precision=precision) =
    new ComputationNode [
        operation = 'SumElements'
        inputs = _AsNodes (matrix, precision=precision) /*plus the function args*/
    ]
# Tanh -- node with operation 'Tanh' over z
Tanh(z, tag='', precision=precision) =
    new ComputationNode [
        operation = 'Tanh'
        inputs = _AsNodes (z, precision=precision) /*plus the function args*/
    ]
# TimeReverse -- node with operation 'TimeReverse' over a vector sequence
TimeReverse(vectorSequence, tag='', precision=precision) =
    new ComputationNode [
        operation = 'TimeReverse'
        inputs = _AsNodes (vectorSequence, precision=precision) /*plus the function args*/
    ]
# Trace -- node with operation 'Trace' wrapped around 'node'
Trace (node, say='', logFrequency=100, logFirst=10, logGradientToo=false, onlyUpToRow=100000000, onlyUpToT=100000000, format=[], tag='', precision=precision) =
    new ComputationNode [
        operation = 'Trace'
        inputs = _AsNodes (node, precision=precision)
    ]
# TransposeTimes -- node with operation 'TransposeTimes'; inputs are (leftMatrix : rightMatrix)
TransposeTimes(leftMatrix, rightMatrix, tag='', precision=precision) =
    new ComputationNode [
        operation = 'TransposeTimes'
        inputs = _AsNodes (leftMatrix : rightMatrix, precision=precision) /*plus the function args*/
    ]
# QuantizedTimes -- node with operation 'QuantizedTimes'; inputs are (leftMatrix : rightMatrix)
QuantizedTimes(leftMatrix, rightMatrix, bitSmoothingA=1, bitSmoothingB=1, outputRank=1, inferInputRankToMap=-1, tag='', precision=precision) =
    new ComputationNode [
        operation = 'QuantizedTimes'
        inputs = _AsNodes (leftMatrix : rightMatrix, precision=precision) /*plus the function args*/
    ]
# Where -- node with operation 'Where' over cond
Where(cond, tag='', precision=precision) =
    new ComputationNode [
        operation = 'Where'
        inputs = _AsNodes (cond, precision=precision) /*plus the function args*/
    ]

# Cast -- node with operation 'Cast' over 'node'.
# Unlike the other wrappers in this section, 'precision' defaults to '' here
# and is not passed on to _AsNodes().
Cast(node, tag='', precision='') =
    new ComputationNode [
        operation = 'Cast'
        inputs = _AsNodes (node) /*plus the function args*/
    ]

##############################################################################
# non-neural-network functions
##############################################################################

# Print -- PrintAction on 'value' ('format' is currently not passed on)
Print(value, format='') =
    new PrintAction [
        what = value /*; how = format*/
    ]
# Fail -- FailAction
Fail(what) =
    new FailAction [ /*what*/ ]
# string helpers, all implemented by StringFunction and selected through 'what'
Format(value, format) =
    new StringFunction [
        what = 'Format'
        arg = value
        how = format
    ]
Replace(s, from, to) =
    new StringFunction [
        what = 'Replace'
        arg = s
        replacewhat = from
        withwhat = to
    ]
Substr(s, begin, num) =
    new StringFunction [
        what = 'Substr'
        arg = s
        pos = begin
        chars = num
    ]
Chr(c) =
    new StringFunction [
        what = 'Chr'
        arg = c
    ]
# Length -- NumericFunction 'Length' of x
Length(x) =
    new NumericFunction [
        what = 'Length'
        arg = x
    ]
# Repeat -- array consisting of N copies of 'what' (BS.Constants.None for N <= 0)
# Can also be used to turn a scalar into a 1-element array.
Repeat (N, what) =
    if N > 0
    then (Repeat (N-1, what) : what)   # recurse: N-1 copies, then append one more
    else BS.Constants.None
# _ForceResizeArray -- bring an array (or scalar) to length N:
#  - longer than N:  keep the first N elements
#  - shorter than N: pad by repeating the last element
#  - empty:          fails, since there is no element to repeat
_ForceResizeArray (N, arrayOrScalar) = {
    items = _AsArray (arrayOrScalar)
    count = Length (items)
    resized =
        if N < count
        then array[0..N-1] (i => items[i])                              # chop to length
        else if count == 0
        then Fail ("BottomlessExpansion(): needs at least one element to expand.")
        else _ConcatArrays (items, Repeat (N-count, items[count-1]))    # append copies of the last value
}.resized
# _AsArray -- pass arrays through, wrap anything else into a 1-element array.
# This allows dimensions to describe scalars (42) or tensors (13:42).
_AsArray (x) =
    if IsArray (x)
    then x
    else [| x |]
# _ConcatArrays -- concatenate two arrays; either argument may also be a scalar.
# Yields BS.Constants.None if the result would be empty.
_ConcatArrays (aOrScalar, bOrScalar) = {
    first  = _AsArray (aOrScalar)
    second = _AsArray (bOrScalar)
    firstLen = Length (first)
    totalLen = firstLen + Length (second)
    joined =
        if totalLen == 0
        then BS.Constants.None
        else array[0..totalLen-1] (i => if i < firstLen then first[i] else second[i-firstLen])
}.joined
# Sign -- -1, 0, or 1 depending on the sign of x
Sign(x) =
    if x < 0 then -1
    else if x > 0 then 1
    else 0
# Min, Max -- the smaller/larger of two values (b if they compare equal)
Min(a,b) = if b > a then a else b
Max(a,b) = if b < a then a else b
# Fac -- factorial, computed recursively; yields 1 for n <= 1
Fac(n) = if n <= 1 then 1 else n * Fac(n-1)
# comparison/type predicates, all implemented by CompareFunction and selected through 'what'
IsSameObject(a,b) =
    new CompareFunction [
        what = 'IsSameObject'
        args = (a : b)
    ]
IsArray(a) =
    new CompareFunction [
        what = 'IsArray'
        args = a
    ]
IsDouble(a) =
    new CompareFunction [
        what = 'IsDouble'
        args = a
    ]
IsBool(a) =
    new CompareFunction [
        what = 'IsBool'
        args = a
    ]
# numeric helpers on two arguments, implemented by NumericFunction
Mod(x, y) =
    new NumericFunction [
        what = 'Mod'
        args = (x:y)
    ]
IntDiv(x, y) =
    new NumericFunction [
        what = 'IntDiv'
        args = (x:y)
    ]

##############################################################################
# macros from NDL book
##############################################################################

# deprecated--use LinearLayer{} and DenseLayer{} instead
# BFF -- record with a bias B, a weight W, and the affine result z = W*in+B
BFF(in, rows, cols) = [
    B = Parameter(rows, 1, initValue = 0)
    W = Parameter(rows, cols)
    z = W*in+B
]
# SBFF -- record with Eh = Sigmoid of the BFF result z
SBFF(in, rows, cols) = [
    Eh = Sigmoid(BFF(in, rows, cols).z)
]

# deprecated--use FeatureMVNLayer{} instead
# MeanVarNorm -- PerDimMeanVarNormalization of 'feat' using Mean() and InvStdDev() of 'feat' itself
MeanVarNorm(feat) =
    PerDimMeanVarNormalization(feat,
                               Mean(feat),
                               InvStdDev(feat))

# deprecated--use LogPriorLayer{} instead
# LogPrior -- Log of the Mean of the labels
LogPrior(labels) =
    Log(Mean(labels))

# specify one of these two for initialization:
#  - init = "uniform"|"gaussian"
#  - embeddingFile = PATHNAME
# deprecated--use EmbeddingLayer{} instead
# Embedding -- embed 'input' through a learnable table and return the looked-up result.
#  - embeddingDim, inputDim:            passed to Parameter() as (inputDim, embeddingDim); the table is its Transpose
#  - initFrom, embeddingPath:           passed to Parameter() as 'init' and 'initFromFilePath'
#  - learningRateWeight:                passed to Parameter() as 'learningRateMultiplier' (default 0.0)
#  - sparseInput:                       selects between a product (embedding * input) and GatherPacked (input, embedding)
Embedding (embeddingDim, input, inputDim=input.dim, initFrom=''/*|fromFile|gaussian|uniform*/, embeddingPath = '', sparseInput = false, learningRateWeight = 0.0) = [
    embedding = Transpose (Parameter (inputDim, embeddingDim, learningRateMultiplier = learningRateWeight, init = initFrom, initFromFilePath = embeddingPath))
    lookup = if sparseInput then embedding * input
             else GatherPacked (input, embedding)
].lookup

##############################################################################
# the more specific standard things are in a namespace called 'BS'
# You can create shorthands for accessing these, e.g. saying B = BS.Boolean.
# Note: Identifiers beginning with _ should be considered for library use only.
##############################################################################

BS = [

##############################################################################
# Basic constants
##############################################################################

# Constants -- commonly used constant tensors, helpers that create constants shaped after
# another node, the numeric True/False values, and the None marker.
Constants = [
    # constant tensors of shape (1) holding 0 and 1, respectively
    Zero = ConstantTensor (0, (1))
    One  = ConstantTensor (1, (1))
    # constant tensor of the given dimensions, filled with 1
    OnesTensor (dims) = ConstantTensor (1, dims)
    # A constant 1D tensor that contains an integer enumeration 0 .. (dim - 1).
    Range (dim) = Splice( array [0..(dim-1)] ( i => Constant{i} ) )
    # BUGBUG: ZeroesLike() would recreate the full dimension of x. Well, no need if it considers broadcasting. But still wrong if we want to broadcast a vector of different tensor dim.
    #ZeroesLike (x) = CastAs (x, Zero) // read: Cast<x>(Zero)
    #OnesLike (x)   = CastAs (x, One)
    # CastAs() does not implement broadcasting
    ZeroesLike (x) = SumColumnElements (RowSlice (0, 1, x) .* Zero)  // hack: get one row of input and multiply with zero; double-hack: reduce extra tensor dims by SumCol
    ZeroSequenceLike = ZeroesLike   # TODO: this should yield a scalar sequence, while ZeroesLike should be a tensor
    ZeroesLike1 (x) = x .* Zero     # get a tensor of zeroes of same dim as x  TODO: Do this as a C++ node (will be simple)
    # ones, derived from ZeroesLike() by adding One
    OnesLike (x) = ZeroesLike (x) + One
    # is this like Sequences.Repeat?
    # numeric Boolean values
    True  = 1
    False = 0
    None = [| |]  # doubles up as an empty array. Note: only use [| |] syntax inside here, as it may change in the future
    # test for None; this is an object-identity comparison against the None object above
    IsNone (x) = IsSameObject (x, None)
]

##############################################################################
# Boolean operations
# These operations will have undefined behavior for input values != 0 or 1.
##############################################################################

# boolean helpers
# Boolean -- logical operations on values that are 0 (false) or 1 (true).
# All operations are implemented arithmetically and elementwise.
Boolean = [
    True  = 1
    False = 0

    # basic logical operations
    And (a,b) =         a .* b
    Or  (a,b) = a + b - a .* b
    # a + b - 2ab. Note: The factor 2 must be applied elementwise (.*) like all other
    # products here; '*' would denote a matrix product instead.
    Xor (a,b) = a + b - a .* b .* Constant (2)
    Not (x)   = Constants.One - x

    # on each time step where clk is 1, this toggles its value
    Toggle (clk, initialValue=False) = [
        state = Xor (PastValue (1, state, defaultHiddenActivation=initialValue), clk)
    ].state

    # select a value
    # Note: This will be replaced by BrainScript 'if cond then thenVal else elseVal' and SwitchNode
    If (cond, thenVal, elseVal, tag='') =  new ComputationNode [ operation = 'If' ; inputs = _AsNodes (cond : thenVal : elseVal) /*plus the function args*/ ]
]

##############################################################################
# sequence operations
# These mimic LINQ operations.
##############################################################################

Sequences = [
    # broadcast a single-step sequence to a multi-step sequence
    # BroadcastSequenceAs -- repeat the single step 'data1' over all steps of a sequence shaped like 'type'.
    # Implemented as a recurrence: the value is placed at the first step, and every
    # later step copies its predecessor.
    BroadcastSequenceAs (type, data1) = [                      # type=example sequence with desired length (outside of a loop), data1=1 time step
        # BUGBUG: This should work but gives worse results.
        #ZeroSequenceLike (x) = RowSlice (0, 1, x) .* Constants.Zero # BUGBUG: SumColumnElements() has a CPU/GPU problem
        #index = /*Constants.*/ZeroSequenceLike (type)  # create an index sequence [ 0 0 0 ... ] of target length
        #packedIndex = PackedIndex (data1, index)       # convert into internal packed index w.r.t. 'data1'
        #out = GatherPacked (packedIndex, data1)        # copy data1[0] to all elements, total length like 'type'

        # alternative (slower, older) implementation (10% slower end-to-end?)
        # Gives nearly the same result, but not completely. Since Gather() above has an atomicAdd(), let's leave this on for now and check later.
        dataPadded = Sequences.Scatter (Loop.IsFirst (type), data1) # padded with zeroes until end of target sequence
        out = Boolean.If (Loop.IsFirst (dataPadded), # if first entry
                 /*then*/ dataPadded,                # then copy that
                 /*else*/ Loop.Previous (out))       # else just propagate to the front
    ].out

    # rolling window over past N samples
    # returns a record [ value=..., valid=... ], both being 1-step sequences of [K x N] (K=dim for value and 1 for valid). N can optionally be moved to axes >2.
    # This implementation is suboptimal in that it creates copies for the intermediate steps.
    # PastValueWindow -- gather the last N steps of 'in', as seen from the last step of the sequence.
    #  - N:    window size (number of delays, including delay 0)
    #  - in:   input sequence
    #  - axis: axis that receives the N delays; must be >= 2
    # Returns the full record; the results are its members 'value' and 'valid'.
    PastValueWindow (N, in, axis=2) = [
        isLast = Loop.IsLast (in)
        isLastIndex = PackedIndex (in, Where (isLast))
        GatherLast (x) = GatherPacked (isLastIndex, x) # 'cond' matches 'x'
        onesLikeIn = Constants.OnesLike (in)
        delayLine[t:0..N-1] = [     # shift register for encoder, last N inputs
            value = if t == 0
                    then in         # delay 0: current value
                    else PastValue (0, in, timeStep=t, defaultHiddenActivation=0)
            valid = if t == 0
                    then onesLikeIn   # BUGBUG: if I say Constant.Ones here, it outputs 0. Ones has no MBLayout
                    else PastValue (1, onesLikeIn, timeStep=t, defaultHiddenActivation=0)


            # pass-through; swap in the commented-out Trace() call below to log the values
            TraceDenseTransposed (h, what) = h
            #    Trace (h, say=what, logFirst=10, logFrequency=100, logGradientToo=false, onlyUpToRow=9, onlyUpToT=25, format=[ type = "real" ; transpose = true ; precisionFormat = ".4" ])


            lastValue = TraceDenseTransposed(  GatherLast (value)  ,'dvalue')  # [i, delay]
            lastValid = TraceDenseTransposed(  GatherLast (valid)  ,'dvalid')  # [i, delay]
        ]
        # delayLine[t].value = value of t steps in the past
        # delayLine[t].valid = true if we had a value t steps in the past
        # split the stacked delays off into their own axis; axis == 2 is the natural position
        SplitStack (x) =
            if      axis == 2 then SplitDimension (x, 1, N)
            else if axis > 2  then TransposeDimensions (SplitDimension (x, 1, N), 2, axis)
            else Fail ("PastValueWindow: axis>=2 required.") # BUGBUG: We also require that input is a single vector. Address later.
        value = SplitStack (RowStack (array[0..N-1](t=>delayLine[t].lastValue)))  # [i, delay]
        valid = SplitStack (RowStack (array[0..N-1](t=>delayLine[t].lastValid)))  # [i, delay]
    ]

    # fold left/right: Reduce entire sequence by applying binaryOp, e.g. FoldL (Plus, 0, input)
    # LINQ calls this Aggregate; and may or may not specify the seed value; and allows a predicate
    #  - binaryOp: function of (current element, accumulator so far)
    #  - x0:       seed value; must be a scalar constant since it is used as the delay's defaultHiddenActivation
    #  - x:        the sequence to reduce
    # FoldL accumulates front to back and returns the value at the last step;
    # FoldR accumulates back to front and returns the value at the first step.
    FoldL (binaryOp, x0, x) = _Fold (PastValue,   Last,  binaryOp, x0, x)
    FoldR (binaryOp, x0, x) = _Fold (FutureValue, First, binaryOp, x0, x)
    # private helper: DelayFn determines the direction, PickResult selects the final accumulator value
    _Fold (DelayFn/*PastValue or FutureValue*/, PickResult/*Last or First*/, binaryOp, x0, x) = [
        acc = binaryOp (x, DelayFn (0, acc, defaultHiddenActivation=x0))  # recurrence; x0 is used where there is no neighbor
        out = PickResult (acc)
    ].out
    # TODO: need a version that does not require an initial value--what would that be called?

    # LINQ-like operators
    # Map -- apply 'lambda' to x
    Map (lambda, x) = lambda (x)
    # Reverse (x) is a C++ node currently called TimeReverse

    # Gather and Scatter
    # Each is composed of three nodes: Where(), PackedIndex(), and GatherPacked()/ScatterPacked().
    Gather (cond, x) =                                            # 'cond' matches 'x'
        GatherPacked (PackedIndex (/*layout of*/ x, Where (cond)),
                      x)
    Scatter (cond, y) =                                           # 'cond' matches the result
        ScatterPacked (/*layout of*/ cond,
                       PackedIndex (/*layout of*/ cond, Where (cond)),
                       y)

    # sequence-altering LINQ-like operators
    # These generate new data packing (MBLayouts)

    # First and Last
    # LINQ allows predicates as well.
    First (x) =
        Slice (0, 1, x, axis=-1)
    Last (x) =
        Slice (-1, 0, x, axis=-1)

    # TakeWhile and DropWhile
    #TakeWhile (predicate, x) = Filter ( _WhilePredicate (PastValue, predicate), x)
    #SkipWhile (predicate, x) = Filter (!_WhilePredicate (PastValue, predicate), x)
    #_WhilePredicate (DelayFn, predicate, input) =
    #[
    #    whilePredicateRec = Boolean.And (DelayFn (whilePredicateRec, defaultHiddenActivation=Boolean.True), predicate)
    #].whilePredicateRec
    # TODO: do we need operations from the back?

    #Take (N, x) = _Take (PastValue, N, x)
    #TakeRight (N, x) = _Take (FutureValue, N, x)
    #_Take (DelayFn, N, x) = [
    #    selected = Loop._IsWithin (DelayFn, N, x)
    #    out = Gather (selected, x)
    #].out
    #
    #Skip (N, x) = if N > 0 then _Skip (PastValue, N, x) else x
    #_Skip (DelayFn, N, x) = [ // TODO: merge with _Take
    #    selected = Loop._IsWithin (DelayFn, N, x)
    #    out = Gather (Boolean.Not (selected), x)
    #].out
    #ElementAt (n, x) = [ // not efficient, as it filters twice. Better AND the predicates. TODO: what if n is out of range? ElementAtOrDefault
    #    startMask = Skip (n, x)                     // ...000111...
    #    mask = startMask - PastValue (0, startMask) // ...000100...
    #    out = Gather (mask, x)
    #]
    #Single (predicate, x) = x

    #FirstOrDefault (x) = ? // can empty sequences exist or even be represented by CNTK?

    #Average (x) = Sum (x) / Loop.Count(x)  // TODO: patch opQuotient to check 0/0 = 0
    #Sum (x)    = FoldL (Plus,    0, x)
    #LogSum (x) = FoldL (LogPlus, 0, x)
    #Max (x) = FoldL (^.Max, ?, x) // TODO: name clash; need to implement ^.
    #Min (x) = FoldL (^.Min, ?, x) // TODO: what's the init value?
    #All (x) = FoldL (Boolean.And,  OnesLike (x), x)
    #Any (x) = FoldL (Boolean.Or, ZeroesLike (x), x)

    # Join to create 2D fields for s2s attention?

    # Concat  (a Zip but in sequence dimension)
]

##############################################################################
# index operations
# These refer to the loop iteration itself.
##############################################################################

# Loop -- helpers that refer to the position within a recurrence (sequence).
Loop = {
    # get the current iteration index w.r.t a node in a loop, such as a Delay node
    # Note: OnesLike() lives in Constants, and PastValue() takes the dimension as its first argument.
    Iteration (x) = [
      agg = Constants.OnesLike (x) + PastValue (0, agg, defaultHiddenActivation=0) // a recurrence that sums up ones
    ].agg

    # get the total length of a sequence
    # TODO: in LINQ, this is an aggregation operation, so it would belong into Sequences
    Count(x) = Sequences.Last (Iteration (x)) // take last item of recurrence that sums up ones

    # is the current iteration the first/last of the loop?
    IsFirst (x) = _IsWithin (PastValue,   1, x)
    IsLast (x)  = _IsWithin (FutureValue, 1, x)

    # is the current iteration among the first/last N of the loop?
    IsFirstN (N, x) = _IsWithin (PastValue,   N, x)
    IsLastN  (N, x) = _IsWithin (FutureValue, N, x)

    # private helpers
    # flag whether a frame is within the first or last N frames
    # This delays a sequence of zeroes (false) by N steps, with the boundary default being True.
    _IsWithin (DelayFn/*PastValue or FutureValue*/, N, x) = DelayFn (0, Constants.ZeroesLike (x)/*false*/, timeStep=N, defaultHiddenActivation=Constants.True)

    # opposite of Id's "next x = ..."
    Previous (x, initialState=None) = PastValue   (0, x, initialState=initialState, timeStep=1)
    Next     (x, initialState=None) = FutureValue (0, x, initialState=initialState, timeStep=1)

    PreviousOrDefault (x, defaultValue=Constant (0)) =   # a delay node with initial value  --TODO: merge the two, then do in C++
    [
        flags = BS.Loop.IsFirst (x)
        out = BS.Boolean.If (flags,
                    /*then*/ BS.Sequences.Scatter (flags, defaultValue),
                    /*else*/ Previous (x))
    ].out

    NextOrDefault (x, defaultValue=Constant (0)) =   # a delay node with initial value
    [
        flags = BS.Loop.IsLast (x)
        out = BS.Boolean.If (flags,
                    /*then*/ BS.Sequences.Scatter (flags, defaultValue),
                    /*else*/ Next (x))
    ].out
}

##############################################################################
# parameter definitions
##############################################################################

# Parameters -- shorthands for creating learnable parameters, and the self-stabilizer.
Parameters =
[
    # TODO: These all have randomSeed set to 1!
    # weight matrix of shape (outputDim : inputDim), uniform init
    WeightParam (outputDim, inputDim) = ParameterTensor ((outputDim : inputDim), init='uniform', initValueScale=1, initOnCPUOnly=true, randomSeed=1)
    DiagWeightParam (outputDim)       = ParameterTensor ((outputDim), init='uniform', initValueScale=1, initOnCPUOnly=true, randomSeed=1) # meant to be applied elementwise
    # bias vector, initialized to 0
    BiasParam (dim)                   = ParameterTensor ((dim), initValue=0.0)
    # a single scalar parameter, initialized to 0
    ScalarParam()                     = BiasParam (1)

    # route input through an extra weight, for stabilization
    # Returns beta .* x with a learnable beta of dimension inputDim, or x itself if not enabled.
    StabilizeElements (x, inputDim=x.dim, enabled=true) =
        if enabled
        then [
            #beta = Exp (BiasParam ((inputDim))) # init value is 0
            #beta = ParameterTensor ((inputDim), initValue=1.0) # init value is 1
            # or SoftPlus: ln(1+e^beta)
            #beta = Log (Constants.One + Exp (ParameterTensor ((inputDim), initValue=0.54132485/*ln (e-1)*/))) # init value is 1

            # sharpened Softplus: 1/f ln(1+e^{f*beta})
            # this behaves linear for weights around 1, yet guarantees positiveness

            f = ConstantTensor (4, (1))
            fInv = Reciprocal (f)
            beta = fInv .* Log (Constants.One + Exp (f .* ParameterTensor ((inputDim), initValue=0.99537863/* 1/f*ln (e^f-1) */))) # init value is 1

            TraceDense (h, what) = h  # delete h and uncomment Trace to trace the beta values. They are a valuable indicator.
                //Trace (h, say=what, logFirst=10, logFrequency=100, logGradientToo=false, onlyUpToRow=9, onlyUpToT=25, format=[ type = "real" ; transpose = false ; precisionFormat = ".6" ])

            result = TraceDense (    beta,    'beta') .* x
        ].result
        else x

    # and the same with a scalar stabilizer shared across all components
    # (implemented as StabilizeElements() with inputDim=1)
    Stabilize (x, enabled=true) = if enabled then StabilizeElements (x, inputDim=1, enabled=true) else x
]

##############################################################################
# recurrent networks
##############################################################################

RNNs =
[
    # LSTMBlock -- LSTM object with projection and self-stabilization
    # Projection is enabled by passing a cellShape that differs from outputDim (cellShape defaults to outputDim).
    # This is the stateless version that takes the previous state as an input.
    # Constructing it creates the parameters; the result is the function apply (x, prevState, aux=None).
    # apply() returns a dictionary with the members h and c, and dim=outputDim for convenience
    # (plus the internal record _). prevState must have h and c.
    # apply() also takes an optional auxiliary input, e.g. for supporting attention models.
    LSTMBlock (outputDim, cellShape=None, usePeepholes=false, init='glorotUniform', initValueScale=1, enableSelfStabilization=false) =
    {
        cellDim = if Constants.IsNone (cellShape) then outputDim else cellShape

        # parameter helpers
        # we group 4 matrices into one, gives 60% speed-up in some cases
        # Not grouping inputs since random-init scaling would be incorrect.
        # TODO: Remove the function call, just assign it.
        B  = ParameterTensor {(4 * cellDim),             initValue=0}       # a bias
        W  = ParameterTensor {(4 * cellDim : Inferred),  init=init, initValueScale=initValueScale}   # input
        A  = ParameterTensor {(4 * cellDim : Inferred),  init=init, initValueScale=initValueScale}   # aux input (optional)
        H  = ParameterTensor {(4 * cellDim : outputDim), init=init, initValueScale=initValueScale}   # hidden-to-hidden
        Ci = ParameterTensor {(    cellDim),             init=init, initValueScale=initValueScale}   # cell-to-hidden {note: applied elementwise}
        Cf = ParameterTensor {(    cellDim),             init=init, initValueScale=initValueScale}   # cell-to-hidden {note: applied elementwise}
        Co = ParameterTensor {(    cellDim),             init=init, initValueScale=initValueScale}   # cell-to-hidden {note: applied elementwise}

        Wmr = ParameterTensor {(outputDim : cellDim), init=init, initValueScale=initValueScale};  # final projection

        # four independent stabilizers (or Identity if disabled): for h(t-1), c(t-1), c(t), and h(t) before projection
        Sdh = if enableSelfStabilization then StabilizerLayer{} else Identity
        Sdc = if enableSelfStabilization then StabilizerLayer{} else Identity
        Sct = if enableSelfStabilization then StabilizerLayer{} else Identity
        Sht = if enableSelfStabilization then StabilizerLayer{} else Identity

        apply (x, prevState, aux=None) = {
            _ = {     // encapsulate the inner workings

                dh = prevState.h // previous values
                dc = prevState.c

                dhs = Sdh(dh) // previous values, stabilized
                dcs = Sdc(dc)
                # note: input does not get a stabilizer here, user is meant to do that outside

                # projected contribution from input(s), hidden, and bias
                proj4 = if Constants.IsNone (aux)
                        then B + W * x + H * dhs
                        else B + W * x + H * dhs + A * aux

                # split the grouped projection into its four parts, each of size cellDim
                itProj  = Slice (0*cellDim, 1*cellDim, proj4, axis=1)
                bitProj = Slice (1*cellDim, 2*cellDim, proj4, axis=1)
                ftProj  = Slice (2*cellDim, 3*cellDim, proj4, axis=1)
                otProj  = Slice (3*cellDim, 4*cellDim, proj4, axis=1)

                # add peephole connection if requested
                peep(x, c, C) = if usePeepholes then x + C .* c else x

                it = Sigmoid (peep (itProj, dcs, Ci))        // input gate(t)
                bit = it .* Tanh (bitProj)                   // applied to tanh of input network

                ft = Sigmoid (peep (ftProj, dcs, Cf))        // forget-me-not gate(t)
                bft = ft .* dc                               // applied to cell(t-1)

                ct = bft + bit                               // c(t) is sum of both

                ot = Sigmoid (peep (otProj, Sct(ct), Co))    // output gate(t)
                ht = ot .* Tanh (ct)                         // applied to tanh(cell(t))
            }

            # our return values
            c = _.ct          // cell value
            h = if outputDim != cellDim   // output/hidden state
                then Wmr * Sht(_.ht)      // project
                else _.ht                 // no projection
            dim = outputDim
        } // end of apply (x, prevState)
    }.apply

    # LSTMP -- LSTM function with projection and self-stabilization
    # Projection is enabled by passing different values for outputDim and cellDim.
    # This is the stateless version that takes the previous state as an input.
    # It returns a dictionary with the members h and c, and dim=outputDim for convenience
    # (plus the internal record _). prevState must have h and c.
    # This function also takes an optional auxiliary input, e.g. for supporting attention models.
    # Note: inputDim is not referenced in the body; the input weight is created with dimension 0.
    LSTMP (outputDim, cellDim=outputDim, x, aux=Constants.None, auxDim=aux.dim, prevState, enableSelfStabilization=false, inputDim=0) =
    [
        # TODO: Implement this in terms of the one above. Needs to be tested.
        S(x) = Parameters.Stabilize (x, enabled=enableSelfStabilization)

        _ = [     // encapsulate the inner workings

            // parameter macros
            # note: each invocation comes with its own set of weights
            B() = Parameters.BiasParam (cellDim)
            W() = Parameters.WeightParam (cellDim, 0)               // input
            A() = Parameters.WeightParam (cellDim, auxDim)          // aux input
            H() = Parameters.WeightParam (cellDim, outputDim)       // hidden-to-hidden
            C() = Parameters.DiagWeightParam (cellDim)              // cell-to-hidden (note: applied elementwise)

            dh = prevState.h // previous values
            dc = prevState.c

            dhs = S(dh) // previous values, stabilized
            dcs = S(dc)
            # note: input does not get a stabilizer here, user is meant to do that outside

            # projected contribution from input(s) and bias
            # (a macro as well, i.e. each use below creates its own B, W, and A)
            pin() = if Constants.IsNone (aux)
                    then B() + W() * x
                    else B() + W() * x + A() * aux

            it = Sigmoid (pin() + H() * dhs + C() .* dcs)           // input gate(t)
            bit = it .* Tanh (pin() + H() * dhs)                    // applied to tanh of input network

            ft = Sigmoid (pin() + H() * dhs + C() .* dcs)           // forget-me-not gate(t)
            bft = ft .* dc                                          // applied to cell(t-1)

            ct = bft + bit                                          // c(t) is sum of both

            ot = Sigmoid (pin() + H() * dhs + C() .* S(ct))         // output gate(t)
            ht = ot .* Tanh (ct)                                    // applied to tanh(cell(t))
        ]

        # our return values
        c = _.ct                        // cell value
        h = if outputDim != cellDim     // output/hidden state
            then [                      // project
                Wmr = Parameters.WeightParam (outputDim, cellDim);
                htp = Wmr * S(_.ht)
            ].htp
            else _.ht                   // no projection
        dim = outputDim
    ]

    # helper function to delay h and c
    # Callers can provide their own, e.g. useful for beam decoding.
    #  - lstmState:    record with the members h, c, and dim
    #  - initialState: optional record with h and c; None is passed on if not given
    #  - layerIndex:   part of the hook signature; not used by this implementation
    # Returns a record with the delayed h and c, and the unchanged dim.
    PreviousHC (lstmState, initialState=None, layerIndex=0) = {
       h = Loop.Previous (lstmState.h, initialState=if BS.Constants.IsNone (initialState) then None else initialState.h)  # hidden state(t-1)
       c = Loop.Previous (lstmState.c, initialState=if BS.Constants.IsNone (initialState) then None else initialState.c)  # cell(t-1)
       dim = lstmState.dim
    }

    # pass previousHook=BS.RNNs.NextHC instead of PreviousHC to get a right-to-left recurrence
    # Same as PreviousHC, except that h and c are taken from the following time step (Loop.Next).
    # layerIndex is part of the hook signature and not used by this implementation.
    NextHC (lstmState, initialState=None, layerIndex=0) = {
       h = Loop.Next (lstmState.h, initialState=if BS.Constants.IsNone (initialState) then None else initialState.h)  # hidden state(t+1)
       c = Loop.Next (lstmState.c, initialState=if BS.Constants.IsNone (initialState) then None else initialState.c)  # cell(t+1)
       dim = lstmState.dim
    }

    # PreviousHCWithTrainedInitialState -- like PreviousHC, but with a learnable initial state.
    # Constructing it creates two parameters of the given shape (initialized to 0) for h and c;
    # the result is a function of the LSTM state that calls PreviousHC() with these as the initial state.
    PreviousHCWithTrainedInitialState{shape=(0)} = {  # default (0) will infer to all elements for inputs of rank 0
        initialH = ParameterTensor {shape, initValue=0}
        initialC = ParameterTensor {shape, initValue=0}
        apply (x) = PreviousHC (x, initialState={ h = initialH; c = initialC }, layerIndex=0)
    }.apply

    # NextHCWithTrainedInitialState -- like NextHC, but with a learnable initial state.
    # Constructing it creates two parameters of the given shape (initialized to 0) for h and c;
    # the result is a function of the LSTM state that calls NextHC() with these as the initial state.
    NextHCWithTrainedInitialState{shape=(0)} = {  # default (0) will infer to all elements for inputs of rank 0
        initialH = ParameterTensor {shape, initValue=0}
        initialC = ParameterTensor {shape, initValue=0}
        apply (x) = NextHC (x, initialState={ h = initialH; c = initialC }, layerIndex=0)
    }.apply

    # NoAuxInputHook -- default for augmentInputHook: no auxiliary input.
    # Both arguments are ignored; the result is always Constants.None.
    NoAuxInputHook (input, lstmState) = Constants.None

    # this implements a recurrent (stateful) LSTM with projection and self-stabilization
    # It returns a record (h,c). To use its output, say .h
    # By default, this is left-to-right. Pass previousHook=BS.RNNs.NextHC for a right-to-left model.
    #  - previousHook:     function (lstmState, layerIndex=) that supplies the state fed back into the LSTM
    #  - augmentInputHook: function (x, prevState) that supplies an auxiliary input, or Constants.None
    #  - augmentInputDim:  passed to LSTMP as auxDim
    # Note: The inputDim argument is not forwarded; LSTMP is always called with inputDim=0.
    RecurrentLSTMP (outputDim/*h.dim*/, cellDim=BS.Constants.None,
                    x, inputDim=0,
                    previousHook=BS.RNNs.PreviousHC,
                    augmentInputHook=NoAuxInputHook, augmentInputDim=0,
                    layerIndex=0,
                    enableSelfStabilization=false) =
    [
        # copies under different names, to be used as values of same-named arguments below
        # NOTE(review): cellDim1 is not referenced below
        enableSelfStabilization1 = enableSelfStabilization ; cellDim1 = cellDim ; layerIndex1 = layerIndex # workaround

        prevState = previousHook (lstmState, layerIndex=layerIndex1) # recurrent memory. E.g. Previous or Next, with or without initial state, beam reordering etc.

        auxInput = augmentInputHook(x, prevState)   # optionally augment input. Constants.None if none.

        # this closes the recurrence: lstmState depends on prevState, which is derived from lstmState
        lstmState = BS.RNNs.LSTMP (outputDim, cellDim=if BS.Constants.IsNone (cellDim) then outputDim else cellDim, x, inputDim=0, aux=auxInput, auxDim=augmentInputDim, prevState, enableSelfStabilization=enableSelfStabilization1)
    ].lstmState // that's the value we return

    # a stack of recurrent LSTMs (unidirectional)
    #  - layerDims: array of output dimensions, one per layer
    #  - cellDims:  optional array of cell dimensions; defaults to layerDims
    # Layer 0 reads 'input'; every further layer reads the (optionally stabilized) h of the layer below.
    # The augmentInputHook is only applied to layer 0.
    # Returns the array of all layers' LSTM states.
    # Note: The inputDim argument is not referenced in the body.
    RecurrentLSTMPStack (layerDims, cellDims=BS.Constants.None,
                         input, inputDim=0,
                         previousHook=PreviousHC,
                         augmentInputHook=NoAuxInputHook, augmentInputDim=0,
                         enableSelfStabilization=false) =
    [
        # copies under different names, to be used as values of same-named arguments below
        previousHook1 = previousHook ; useStabilizer = enableSelfStabilization ; augmentInputHook1 = augmentInputHook ; augmentInputDim1 = augmentInputDim
        layers[i:0..Length (layerDims)-1] =
            RecurrentLSTMP (layerDims[i], cellDim=if BS.Constants.IsNone (cellDims) then layerDims[i] else cellDims[i],
                            if i == 0 then input else Parameters.Stabilize (layers[i-1].h, enabled=useStabilizer),
                            previousHook=previousHook1,
                            augmentInputHook=if i == 0 then augmentInputHook1 else NoAuxInputHook, augmentInputDim=if i == 0 then augmentInputDim1 else 0,
                            layerIndex=i,
                            enableSelfStabilization=useStabilizer)
    ].layers

    # a stack of recurrent LSTMs (bidirectional)
    # Each layer runs two RecurrentLSTMP instances over the same input: 'fwd' with previousHook
    # and 'bwd' with nextHook. Their h and c are spliced along axis 1, so each layer record
    # exposes h, c, and dim = 2 * layerDims[i].
    # Returns the array of layer records.
    # Parameters:
    #  - layerDims: array of per-direction output dimensions, one per layer
    #  - cellDims: array of per-direction cell dimensions (defaults to layerDims)
    #  - input, inputDim: input to layer 0 and its dimension; higher layers read the
    #    (optionally stabilized) .h and the .dim of the layer below
    #  - previousHook, nextHook: recurrence hooks for the fwd and bwd instances, respectively
    #  - enableSelfStabilization: passed to every RecurrentLSTMP and to the inter-layer Stabilize
    # TODO: Should we define layerDims as the total (sum of both forward and backward direction)?
    RecurrentBidirectionalLSTMPStack (layerDims, cellDims=layerDims, input, inputDim=input.dim, previousHook=PreviousHC, nextHook=NextHC, enableSelfStabilization=false) = [
        # aliases under different names for the arguments forwarded below
        previousHook1 = previousHook ; nextHook1 = nextHook ; useStabilizer = enableSelfStabilization
        layers[i:0..Length (layerDims)-1] =
        [
            # input of this layer and its dimension
            v    = if i == 0 then input    else Parameters.Stabilize (layers[i-1].h, enabled=useStabilizer)
            vDim = if i == 0 then inputDim else                       layers[i-1].dim
            fwd = RecurrentLSTMP (layerDims[i], cellDim=cellDims[i],
                                  v, inputDim=vDim,
                                  previousHook=previousHook1,
                                  layerIndex=i,
                                  enableSelfStabilization=useStabilizer)
            bwd = RecurrentLSTMP (layerDims[i], cellDim=cellDims[i],
                                  v, inputDim=vDim,
                                  previousHook=nextHook1,
                                  layerIndex=i,
                                  enableSelfStabilization=useStabilizer)
            # concatenate both directions
            h = Splice ((fwd.h : bwd.h), axis=1)
            c = Splice ((fwd.c : bwd.c), axis=1)
            dim = layerDims[i] * 2  # output dimension
        ]
    ].layers

    # NOTE: the GRU implementation below has too much code duplication with the LSTM functions; it will be re-written
    # GRU -- GRU function with projection and self-stabilization
    # It returns a dictionary with three members: the hidden state h, the cell state c, and dim=h.dim.
    # While c isn't required, we return it for implementations like seq2seq that expect it so that this can be a proper drop-in replacement for LSTM
    # Like the LSTM function, it also takes an optional auxiliary input, e.g. for supporting attention models.
    # Parameters:
    #  - outputDim: dimension of the returned h
    #  - cellDim: see the NOTE(review) on the 'cellDim = outputDim' member below
    #  - x, inputDim: input and its dimension
    #  - aux, auxDim: optional auxiliary input and its dimension; aux is only used if it is not Constants.None
    #  - prevState: record whose .h is the previous hidden state
    #  - enableSelfStabilization: enables the Parameters.Stabilize calls made through S()
    GRU (outputDim, cellDim=outputDim, x, inputDim=x.dim, aux=Constants.None, auxDim=aux.dim, prevState, enableSelfStabilization=false) =
    [
        # S(x): stabilizer shorthand used throughout this function
        S(x) = Parameters.Stabilize (x, enabled=enableSelfStabilization)
        # NOTE(review): this member has the same name as the cellDim parameter. If it takes
        # precedence inside this record, cellDim always equals outputDim and the projection
        # branch of 'h' below is never taken -- confirm.
        cellDim = outputDim

        _ = [     // encapsulate the inner workings

            dh = prevState.h   // previous value
            dhs = S(dh)        // previous value, stabilized
            # note: input does not get a stabilizer here, user is meant to do that outside

            // parameter macros
            # note: each invocation comes with its own set of weights
            B() = Parameters.BiasParam (cellDim)
            W() = Parameters.WeightParam (cellDim, inputDim)        // input
            A() = Parameters.WeightParam (cellDim, auxDim)          // aux input
            H() = Parameters.WeightParam (cellDim, outputDim)       // hidden-to-hidden

            # projected contribution from input(s) and bias
            # (the aux term is only added when an aux input was passed)
            pin() = if Constants.IsNone (aux)
                    then B() + W() * x
                    else B() + W() * x + A() * aux

            # update gate z(t)
            zt = Sigmoid (pin() + H() * dhs)

            # reset gate r(t)
            # (same expression as zt, but pin() and H() are invoked again here)
            rt = Sigmoid (pin() + H() * dhs)

            # "cell" c
            rs = dhs .* rt
            ct = Tanh (pin() + H() * rs)

            # hidden state ht / output
            # interpolation between the candidate ct (weight 1-zt) and the stabilized previous state dhs (weight zt)
            ht = (BS.Constants.OnesTensor (cellDim) - zt) .* ct + zt .* dhs
        ]

        # our return values (projection)
        h = if outputDim != cellDim     // output/hidden state
        then [                          // project
            Wmr = Parameters.WeightParam (outputDim, cellDim);
            htp = Wmr * S(_.ht)
        ].htp
        else _.ht                       // no projection

        c = _.ct
        dim = outputDim
    ]

    # this implements a recurrent (stateful) GRU with self-stabilization
    # It returns a record (h,c) to be compatible with the LSTM version. To use its output, say .h
    # By default, this is left-to-right. Pass previousHook=BS.RNNs.NextHC for a right-to-left model.
    # Parameters:
    #  - outputDim: dimension of the output h
    #  - cellDim: forwarded to GRU
    #  - x, inputDim: input and its dimension, both forwarded to GRU
    #  - previousHook: function called as previousHook (state, layerIndex=...) to obtain the recurrent state
    #  - augmentInputHook: function called as augmentInputHook (x, prevState); its result is passed to GRU as 'aux'
    #  - augmentInputDim: passed to GRU as 'auxDim'
    #  - layerIndex: forwarded to previousHook
    #  - enableSelfStabilization: forwarded to GRU
    RecurrentGRU   (outputDim/*h.dim*/, cellDim=outputDim,
                    x, inputDim=x.dim,
                    previousHook=BS.RNNs.PreviousHC,
                    augmentInputHook=NoAuxInputHook, augmentInputDim=0,
                    layerIndex=0,
                    enableSelfStabilization=false) =
    [
        # aliases under different names, so that no named argument below mentions its own name on the right-hand side
        enableSelfStabilization1 = enableSelfStabilization ; cellDim1 = cellDim ; inputDim1 = inputDim ; layerIndex1 = layerIndex # workaround

        # prevState and gruState refer to each other; this mutual reference forms the recurrent loop
        prevState = previousHook (gruState, layerIndex=layerIndex1) # recurrent memory. E.g. Previous or Next, with or without initial state, beam reordering etc.

        auxInput = augmentInputHook(x, prevState)   # optionally augment input. Constants.None if none.

        gruState = BS.RNNs.GRU (outputDim, cellDim=cellDim1, x, inputDim=inputDim1, aux=auxInput, auxDim=augmentInputDim, prevState, enableSelfStabilization=enableSelfStabilization1)
    ].gruState // that's the value we return

    # a stack of recurrent GRUs (unidirectional)
    # Creates one RecurrentGRU per entry of layerDims and returns the array of layer records.
    # Parameters:
    #  - layerDims: array of output dimensions, one per layer
    #  - cellDims: array of cell dimensions (defaults to layerDims)
    #  - input, inputDim: input to layer 0 and its dimension; each higher layer reads the .h of the
    #    layer below (passed through Parameters.Stabilize) and uses that layer's .dim
    #  - previousHook: passed unchanged to every layer
    #  - augmentInputHook, augmentInputDim: applied to layer 0 only; higher layers get NoAuxInputHook and 0
    #  - enableSelfStabilization: passed to every layer and to the inter-layer Stabilize
    RecurrentGRUStack (layerDims, cellDims=layerDims,
                       input, inputDim=input.dim,
                       previousHook=PreviousHC,
                       augmentInputHook=NoAuxInputHook, augmentInputDim=0,
                       enableSelfStabilization=false) =
    [
        # Aliases under different names, so that no named argument below has to mention
        # its own name on the right-hand side of the '='. inputDim1 covers inputDim in the same way.
        previousHook1 = previousHook ; useStabilizer = enableSelfStabilization ; augmentInputHook1 = augmentInputHook ; augmentInputDim1 = augmentInputDim ; inputDim1 = inputDim
        layers[i:0..Length (layerDims)-1] =
            RecurrentGRU   (layerDims[i], cellDim=cellDims[i],
                            if i == 0 then input else Parameters.Stabilize (layers[i-1].h, enabled=useStabilizer), inputDim=if i == 0 then inputDim1 else layers[i-1].dim,
                            previousHook=previousHook1,
                            augmentInputHook=if i == 0 then augmentInputHook1 else NoAuxInputHook, augmentInputDim=if i == 0 then augmentInputDim1 else 0,
                            layerIndex=i,
                            enableSelfStabilization=useStabilizer)
    ].layers

    # a stack of recurrent GRUs (bidirectional)
    # Each layer runs two RecurrentGRU instances over the same input: 'fwd' with previousHook
    # and 'bwd' with nextHook. Their h and c are spliced along axis 1, so each layer record
    # exposes h, c, and dim = 2 * layerDims[i].
    # Returns the array of layer records.
    # Parameters:
    #  - layerDims: array of per-direction output dimensions, one per layer
    #  - cellDims: array of per-direction cell dimensions (defaults to layerDims)
    #  - input, inputDim: input to layer 0 and its dimension; higher layers read the
    #    (optionally stabilized) .h and the .dim of the layer below
    #  - previousHook, nextHook: recurrence hooks for the fwd and bwd instances, respectively
    #  - enableSelfStabilization: passed to every RecurrentGRU and to the inter-layer Stabilize
    # TODO: Should we define layerDims as the total (sum of both forward and backward direction)?
    RecurrentBidirectionalGRUStack (layerDims, cellDims=layerDims, input, inputDim=input.dim, previousHook=PreviousHC, nextHook=NextHC, enableSelfStabilization=false) = [
        # aliases under different names for the arguments forwarded below
        previousHook1 = previousHook ; nextHook1 = nextHook ; useStabilizer = enableSelfStabilization
        layers[i:0..Length (layerDims)-1] =
        [
            # input of this layer and its dimension
            v    = if i == 0 then input    else Parameters.Stabilize (layers[i-1].h, enabled=useStabilizer)
            vDim = if i == 0 then inputDim else                       layers[i-1].dim
            fwd = RecurrentGRU   (layerDims[i], cellDim=cellDims[i],
                                  v, inputDim=vDim,
                                  previousHook=previousHook1,
                                  layerIndex=i,
                                  enableSelfStabilization=useStabilizer)
            bwd = RecurrentGRU   (layerDims[i], cellDim=cellDims[i],
                                  v, inputDim=vDim,
                                  previousHook=nextHook1,
                                  layerIndex=i,
                                  enableSelfStabilization=useStabilizer)
            # concatenate both directions
            h = Splice ((fwd.h : bwd.h), axis=1)
            c = Splice ((fwd.c : bwd.c), axis=1)
            dim = layerDims[i] * 2  # output dimension
        ]
    ].layers
]

##############################################################################
# sequence-to-sequence models
# This implements attention model and beam decoding.
##############################################################################

Seq2Seq =
[
    # attention model
    # The attention model is an additional input vector to the LSTM.
    # Here, it is implemented by augmenting this vector to the regular input of the LSTM.
    # The RecurrentLSTMP function does this inside through an optional lambda that the caller can pass in.
    # This function creates such a lambda, which augments the input vector from a fixed-size attention window.
    # Parameters:
    #  - attentionDim: dimension the window and the decoder state are projected to before the Tanh
    #  - attentionSpan: number of frames in the window (passed to Sequences.PastValueWindow)
    #  - decoderDynamicAxis: passed to Sequences.BroadcastSequenceAs as the sequence to broadcast to
    #  - encoderOutput: record; its .h is windowed and its .dim is used as the window's dimension
    #  - enableSelfStabilization: enables the Parameters.Stabilize calls made through S()
    # Returns: the function AugmentInputHook (input, prevState).
    CreateAugmentWithFixedWindowAttentionHook (attentionDim, attentionSpan, decoderDynamicAxis, encoderOutput, enableSelfStabilization=false) =
    [
        # attention (fixed rolling window)
        attentionWindow = Sequences.PastValueWindow (attentionSpan, encoderOutput.h, axis=2) # BUGBUG: We should have this in axis=3 right away for beam search. Track this down.

        S(x) = Parameters.Stabilize (x, enabled=enableSelfStabilization)

        # project it for Tanh() expression
        # expected to be [attentionDim x 1 x attentionSpan], where that 1 is the axis of the beam in beam decoding
        # These members do not depend on the hook's arguments; they are defined once here and referenced from inside the hook.
        projectedAttentionWindowBroadcast = [
            W = Parameters.WeightParam (attentionDim, encoderOutput.dim)
            # inject an additional singleton dimension at second axis, as a stand-in for the beam depth in decoding
            InjectBeamDepth (node) = SplitDimension (node, /*axis*/1, /*N:*/1)
           #projectedValue = Sequences.BroadcastSequenceAs (decoderDynamicAxis, InjectBeamDepth (W * attentionWindow.value)) # apply the projection columnwise to the attentionWindow tensor
            projectedValue = if enableSelfStabilization  # apply the projection columnwise to the attentionWindow tensor
                        then Sequences.BroadcastSequenceAs (decoderDynamicAxis, InjectBeamDepth (W * S(attentionWindow.value .* attentionWindow.valid))) # (mask invalid frames for stabilizer)
                        else Sequences.BroadcastSequenceAs (decoderDynamicAxis, InjectBeamDepth (W *   attentionWindow.value))
            value          = Sequences.BroadcastSequenceAs (decoderDynamicAxis, InjectBeamDepth (      attentionWindow.value))
            valid          = Sequences.BroadcastSequenceAs (decoderDynamicAxis, InjectBeamDepth (      attentionWindow.valid))
            dim            = encoderOutput.dim
        ]

        # the return value of this function is this lambda, which gets passed to the RecurrentLSTMP() function as the augmentInputHook parameter
        # Note: 'input' is not referenced in the body; only prevState is used.
        AugmentInputHook (input, prevState) =
        [
            # compute additional hidden state from attention
            outputDim = prevState.dim
            W = Parameters.WeightParam (attentionDim, outputDim)
            projectedH = W * S(prevState.h)                           # [outputDim] or [outputDim x D] in beam search
            tanHOut = Tanh (projectedAttentionWindowBroadcast.projectedValue + projectedH) # [attentionDim x beamDepth x attentionSpan]

            # You can enable (uncomment) these Trace macros to enable tracing of the attention weights, which is a useful indicator.
            # As written, both macros return their first argument unchanged.
            TraceDense (h, what) = h
                //Trace (h, say=what, logFirst=10, logFrequency=100, logGradientToo=false, onlyUpToRow=9, onlyUpToT=25, format=[ type = "real" ; transpose = false ; precisionFormat = ".4" ])
            TraceDenseTransposed (h, what) = h
                //Trace (h, say=what, logFirst=10, logFrequency=100, logGradientToo=false, onlyUpToRow=9, onlyUpToT=25, format=[ type = "real" ; transpose = true ; precisionFormat = ".4" ])

            v = TraceDenseTransposed(    Parameters.WeightParam (1, attentionDim)     ,'v')                           # [1 x attentionDim]
            u = v * S(tanHOut .* projectedAttentionWindowBroadcast.valid) # [1 x beamDepth x attentionSpan]
            # ^^ mask 'v' for purpose of stabilization; TODO: don't do that if no stabilization
            # adding Log (valid) before the Softmax: presumably 'valid' is a 0/1 mask, so invalid frames get Log (0) -- TODO confirm
            uValid = u + Log (projectedAttentionWindowBroadcast.valid)    # [1 x beamDepth x attentionSpan]

            attentionWeights = Softmax (uValid, axis=3)                    # [1 x beamDepth x attentionSpan]
            weightedAttentionWindow = projectedAttentionWindowBroadcast.value .* TraceDense(  attentionWeights    ,'weights') # [encoderHiddenDim x beamDepth x attentionSpan]
            # TODO: use ReduceSum:
            # this is the auxiliary input to the LSTMP function
            # (sum over the attentionSpan axis, implemented as a product with a vector of ones)
            weightedAttentionAverage = S(Times (weightedAttentionWindow, BS.Constants.OnesTensor (attentionSpan), outputRank=2)) # [encoderHiddenDim x beamDepth]
        ].weightedAttentionAverage
    ].AugmentInputHook

    # helper macro that extracts top D hypotheses from a 2D tensor
    # input: scores[w,n]    w = word index, n = hyp index in beam (n=0 is the best one)
    # output: [w,n1,n2]     n1 = input hyp index (prev top N); n2 = output hyp index (new top N)
    # e.g. 4 words, beam 3; view this as 3 [4x3] planes "drawn" 3-dimensionally, with depth being the 3rd tensor index
    # Parameters:
    #  - D: number of entries to extract
    #  - scores: the 2D score tensor
    # Returns: the D Hardmax results, spliced along axis 3.
    GetTopNTensor (D, scores) = [
        # recurse over up to D elements
        # In each recursion:
        #  - pick the best over (w,n)
        #  - subtract it out from scores
        recursion[n:0..D-1] =
        [
            curBestScores = if n == 0                            # scores excluding paths better than rank n
                            then scores                          # top: just the path scores
                            else recursion[n - 1].nextBestScores # next: path scores after removing all we already got
            best = Hardmax (curBestScores)                       # best = one-hot over (w,n)
            # -1e30 is added at the position just picked, so it cannot be picked again in a later step
            nextBestScores = curBestScores + Constant (-1e30) .* best     # set the ones we've already got to -INF
            # TODO: use proper -INF; e.g. -1/0 in BS. Needs to be tested thoroughly.
        ]
        # splice them together into a single tensor
        asArray[n:0..D-1] = recursion[n].best  # this is a BS array consisting only of the 'best' field    ('from r in recursion select r.best')
        spliced = Splice (axis = 3, asArray)   # convert BS array index n to tensor index n1
    ].spliced

    # Create a greedy decoder model from an existing trained model.
    # The input model is expected to have these nodes:
    #  - decoderHistoryFromOutput: the decoding output of a time step (Hardmax (outputProbability))
    #  - decoderHistoryHook: a node that is the word sequence that will be used as the history for the next time step
    #    In training, this is the label sequence.
    #    In greedy decoding, it must be decoderHistoryHook = decoderHistoryFromOutput
    #  - z: scaled log prediction probability   --TODO: rename this: scoreSequence = Pass (z)
    #  - inputSequence
    #  - labelSequence (only passed through for scoring, not used in decoding)
    # The returned model has the following one-hot outputs:
    #  - decodedSequence  --TODO: currently decodeOut; rename this
    #  - inputSequence
    #  - labelSequence
    # To decode greedily, in "write" or "eval" specify the model as:
    #    BrainScriptNetworkBuilder = (BS.Seq2Seq.GreedySequenceDecoderFrom (BS.Network.Load ("$decodeModelPath$")))
    GreedySequenceDecoderFrom (modelAsTrained) = [
        scoreSequence = modelAsTrained.z
        # the three nodes below are tagged as outputs and passed to Edit() as additional roots
        decodeOut = Pass (      Hardmax (scoreSequence), tag='output')
        inputsOut = Pass (modelAsTrained.inputSequence,  tag='output')
        labelsOut = Pass (modelAsTrained.labelSequence,  tag='output')
        # rewire the model: every link to decoderHistoryHook is replaced by a link to decoderHistoryFromOutput
        model = BS.Network.Edit (modelAsTrained,
                                 #BS.Network.Editing.ReplaceLinksToNode (modelAsTrained.decoderInput/*delayedDecoderFeedback*/, delayedDecoderFeedback),
                                 BS.Network.Editing.ReplaceLinksToNode (modelAsTrained.decoderHistoryHook, modelAsTrained.decoderHistoryFromOutput),
                                 decodeOut : inputsOut : labelsOut)
    ].model

    # turning a regular LSTM to a top-N beam-search decoder:
    #  - add a depth axis of dimension N to all nodes inside the decoder loop
    #     - only needs the init signal for PastValue to be that
    #  - h and c must be shuffled versions of their PastValue
    #     - since what are the top N in one time step is not the top N in the next
    #     - reshuffling and adding depth to the init signal can be done at the same place
    #  - decoder output must determine the top N and a reshuffling matrix for h and c
    #     - the current Hardmax needs to be replaced by something that outputs these (output depth N)
    #     - we get a N^2 depth: [V x (input set) x (top N output hypos)]
    #     - reshuffling matrix is reduction over V (multiply with row of V ones) plus possibly a transposition
    #  - we need an accumulated path score
    #     - start value constructed by stacking a 0 and N-1 -INF
    #  - for testing, we can output the current best in each step
    #     - that's a Slice()
    #  - traceback is a right-to-left recurrence
    #     - output best hypo conditioned on the path (it is already known)
    # beam search of width 'beamDepth'
    # Parameters:
    #  - modelAsTrained: the trained model. This function reads its members z, inputSequence, labelSequence,
    #    labelSentenceStartEmbeddedScattered, beamSearchReorderHook, and decoderHistoryHook.
    #  - beamDepth: number of hypotheses kept per time step
    # Returns: the edited model, with inputsOut, labelsOut, and decodeOut as additional roots.
    # Note: members are referenced before their textual definition throughout (e.g. LOGZERO, Columnwise, FirstAndOther).
    BeamSearchSequenceDecoderFrom (modelAsTrained, beamDepth) = [

        scoreSequence = modelAsTrained.z
        vocabSize    = scoreSequence.dim

        # TODO: use ReduceSum
        # ReduceAxis: sums x over axis 1 by multiplying with a ones tensor; axis 2 is handled by first swapping axes 1 and 2
        ReduceAxis (axisDim, x, axis=1) =   # unfortunately, we must feed in the dimension of the axis, it can't be inferred
            if      axis == 1 then Times (Constants.OnesTensor (axisDim), x, outputRank=0)
            else if axis == 2 then ReduceAxis (axisDim, TransposeDimensions (x, 1, 2), axis=1)
            else Fail("ReduceAxis: Only supports axes 1 and 2.")

        # === BEGIN DECODER ===

        # constants for initial score and final traceback
        initialPathScores = FirstAndOther (0, LOGZERO, beamDepth, axis = 2)  # [1 x D]: [ 0, -INF, -INF, -INF, ... ]
        finalHyp          = FirstAndOther (1, 0,       beamDepth, axis = 1)  # [D] the final token is the top-scoring hypothesis, that is, hyp[0]

        # path expansion of the D hypotheses that were best in previous time step (ordered as in previous time step)
        logLLs = Columnwise (LogSoftmax, beamDepth, scoreSequence)                                      # [V x Dprev] log  P(w|hist)
        expandedPathScores = logLLs + Boolean.If (Loop.IsFirst (logLLs), initialPathScores, Loop.Previous (tokens.score)) # [V x Dprev] log (P(w|hist) * P(hist)) for all top D hypotheses

        # determine top D of expanded paths
        topPaths      = GetTopNTensor (beamDepth, expandedPathScores) # [V x Dprev] -> [V x Dprev x Dnew]
        topPathScores = topPaths .* expandedPathScores                #                [V x Dprev x Dnew]

        # form new decoding token, by reducing topPaths(Scores) along relevant dimensions
        tokens = [                                    # [. x Dnew]
            from  = ReduceAxis (axis=1, vocabSize, topPaths) # [Dprev x Dnew], reduced over V
            word  = ReduceAxis (axis=2, beamDepth, topPaths) # [V x Dnew], reduced over Dprev
            score = Constants.OnesTensor (1/*output dim*/ : /*reduction dims: */vocabSize : beamDepth/*Dprev*/) * topPathScores # [1 x Dnew], reduced over [V x Dprev] and inserted a '1'
        ]

        # network feedback for next time step
        # BUGBUG: Need to import EmbedLabels functionality from models
        decoderFeedback = /*EmbedLabels*/ (tokens.word) # [embeddingDim x Dnew]
        # NOTE(review): delayedDecoderFeedback is not referenced anywhere else in this function, and it names
        # labelSentenceStartEmbeddedScattered without the modelAsTrained. prefix used in 'traceback' below -- confirm whether it is still needed.
        delayedDecoderFeedback = Boolean.If (Loop.IsFirst (labelSentenceStartEmbeddedScattered), labelSentenceStartEmbeddedScattered, Loop.Previous (decoderFeedback))

        # final traceback
        traceback = Boolean.If (Loop.IsLast (modelAsTrained.labelSentenceStartEmbeddedScattered/*tokens.from*/), finalHyp, Loop.Next (tokens.from * traceback)) # [D] one-hot, multiplying tokens.from from the left will select another one-hot row of tokens.from
        decodeHyp = Times (topPaths, traceback, outputRank=2) # [V x Dprev] 2D one-hot, selected the best hyp according to traceback
        decode = decodeHyp * Constants.OnesTensor (beamDepth) # [V] reduces over Dprev -> 1D one-hot
        # TODO: Can this be done in one ^^ go?

        # === END DECODER ===

        # propagate LSTM state to the right top-N rank given where that rank came from in the previous time step

        # PropagateTopN:
        # tokens.from: [Dprev, Dnew]
        #   v--------- best came from input hyp[1]
        #     v------- second best came from input hyp[0]
        #       v----- third best came from input hyp[2]
        #   0 1 0
        #   1 0 0
        #   0 0 1
        # tokens.from[:,n] one-hot encodes the best predecessor at top-N rank n
        # each column is a one-hot vector
        # multiplying with such a column from the right will select the column represented by the one-hot value

        # logLLs: get decoder log likelihoods

        # initialPathScores: decoder start token: 0 for first hyp, -INF for the others
        # (LOGZERO is the finite stand-in used for -INF)
        LOGZERO = -1e30

        # expandedPathScores: path expansion, [V x 1] + [1 x D] -> [V x D]

        # topPaths:
        #   +-----+
        #   |0 0 0|
        #   |0 0 0|-+
        #   |0 1 0|0|     means word[2] in input hyp[1] was the best
        #   |0 0 0|0|-+
        #   +-----+0|0|
        #     |1 0 0|0|   means word[3] in input hyp[0] was the second best
        #     +-----+1|   means word[2] in input hyp[2] was the third best
        #       |0 0 0|
        #       +-----+

        # tokens.word:
        #tokens.word = ReduceSum (axis=2, topPaths) # TODO: add an axis parameter to SumColumnElements()
        #   +-+
        #   |0|
        #   |0|-+
        #   |1|0|     means word[2] in input hyp[1] was the best
        #   |0|0|-+
        #   +-+0|0|
        #     |1|0|   means word[3] in input hyp[0] was the second best
        #     +-+1|   means word[2] in input hyp[2] was the third best
        #       |0|
        #       +-+

        # tokens.from:
        # before dropping the first dimension: [V x Dprev x Dnew]
        #   +-----+
        #   |0 1 0|       means input hyp[1] gave rise to the best
        #   +-----+-+
        #     |1 0 0|     means input hyp[0] gave rise to second best
        #     +-----+-+
        #       |0 0 1|   means input hyp[2] gave rise to third best
        #       +-----+
        # after: [Dprev x Dnew]        e.g. "0 1 0" goes into first column, vertically
        #   v--------- best came from input hyp[1]
        #     v------- second best came from input hyp[0]
        #       v----- third best came from input hyp[2]
        #   0 1 0
        #   1 0 0
        #   0 0 1
        # tokens.from[:,n] one-hot encodes the best predecessor at top-N rank n

        # topPathScores:
        #   +-----+
        #   |0 0 0|
        #   |0 0 0|-+
        #   |0 x 0|0|     x denotes the accumulated path score max_w P(w|hyp[1])
        #   |0 0 0|0|-+
        #   +-----+0|0|
        #     |y 0 0|0|   y denotes the accumulated path score max_w P(w|hyp[0])
        #     +-----+z|   z denotes the accumulated path score max_w P(w|hyp[2])
        #       |0 0 0|
        #       +-----+

        # traceback:
        # last state: take Hardmax over tokens.score
        # previous states: multiply with respective tokens.from matrix
        # -> hyp index for every time step
        # then finally use that to select the actual output   TODO: That's a sample-wise matrix product between two sequences!!!
        # TODO: condition must be 1-dim, not 2-dim tensor, so we use labelSentenceStartEmbeddedScattered instead of tokens.from
        # +-+
        # |0|
        # |1|  means at this time step, hyp[1] was the best globally
        # |0|
        # +-+

        # decode: and the actual decoding output
        # This is the one to output (top sentence-level hypothesis after traceback).

        # traceback : [Dnew]
        # topPaths : [V x Dprev x Dnew]
        #   +-----+
        #   |0 0 0|
        #   |0 0 0|-+
        #   |0 1 0|0|     means word[2] in input hyp[1] was the best
        #   |0 0 0|0|-+
        #   +-----+0|0|
        #     |1 0 0|0|   means word[3] in input hyp[0] was the second best
        #     +-----+1|   means word[2] in input hyp[2] was the third best
        #       |0 0 0|
        #       +-----+

        # helper macros  --> move to BS.core.bs

        # Columnwise: applies f separately to each of the beamDepth slices of z along axis 2 and splices the results back along axis 2
        Columnwise (f, beamDepth, z) = # TODO: Takes LogSoftmax over axis=1. it is more tricky to do this over arbitrary axes
        [
            cols[d:0..beamDepth-1] = f (Slice (d, d+1, z, axis=2) /*[:,d]*/ )
            out = Splice (cols, axis=2)
        ].out

        # FirstAndOther: builds an N-element constant with firstVal in the first position and otherVals in the remaining N-1;
        # for N == 1 it is just firstVal. axis == 1 gives a column (RowStack), any other axis value splices along that axis.
        FirstAndOther (firstVal, otherVals, N, axis = 1) = if N == 1 then ConstantTensor (firstVal, (1)) else [
            axis1 = axis  # TODO: Is this really necessary? Why? Then we need the syntax   axis = ^.axis or ^axis
            out = if axis == 1  # maybe this can be unified or pushed into Splice?
                  then RowStack (ConstantTensor (firstVal, (1)) : ConstantTensor (otherVals, (N -1)))                                # col vector: [ 1; 0; 0; 0 ... ]
                  else Splice   (Constant       (firstVal)      : ConstantTensor (otherVals, (1 : N -1)), axis = axis1 /*, axis*/)   # row vector: [ 0, -INF, -INF, -INF, ... ]
        ].out

        # rewire the trained model: two link replacements plus three additional output roots
        model = BS.Network.Edit (modelAsTrained,
                                 (
                                     BS.Network.Editing.ReplaceLinksToNode (modelAsTrained.beamSearchReorderHook, tokens.from) :   # reorder LSTM states
                                     BS.Network.Editing.ReplaceLinksToNode (modelAsTrained.decoderHistoryHook,    decoderFeedback) # feed decoder output back in
                                 ),
                                 (inputsOut : labelsOut : decodeOut)) # additional roots

        inputsOut = Pass (modelAsTrained.inputSequence, tag='output')
        labelsOut = Pass (modelAsTrained.labelSequence, tag='output')
        decodeOut = Pass (decode, tag='output')
    ].model
]

##############################################################################
# Network-level operations
# These operations load, clone, and edit entire computation networks.
##############################################################################

Network = [
    # Load -- creates a ComputationNetworkFromFile object.
    # The record passed to 'new' is empty; presumably the object reads pathName from the enclosing scope -- TODO confirm.
    Load(pathName) = new ComputationNetworkFromFile [ /*pathName; also needs 'precision' somewhere*/ ]
    # CloneFunction -- creates a CloneFunctionConfigLambda object.
    # 'parameters' defaults to "learnable"; the inline comment lists "constant" and "shared" as the other values.
    CloneFunction (inputNodes, outputNodes, parameters="learnable" /*|"constant"|"shared"*/) = new CloneFunctionConfigLambda [ /*args*/ ]
    # Edit -- creates a ComputationNetworkWithEdits object from a model, edit function(s), and additional root nodes.
    Edit(inputModel, editFunctions, additionalRoots) = new ComputationNetworkWithEdits [ /*inputModel, editFunctions, additionalRoots*/ ]

    # Editing -- constructors for edit functions to be passed to Edit() as 'editFunctions'.
    # Each returns a lambda that maps a node to either itself or a replacement.
    Editing = [
        // Create a lambda that returns its argument unless that argument == 'old', then it will return 'replacement'
        ReplaceLinksToNode (old, replacement) = (node => if IsSameObject (node, old) then replacement else node)
        # Same, but the node to replace is identified by comparing node.name against 'name'
        ReplaceLinksToNamedNode (name, replacement) = (node => if node.name == name then replacement else node)
    ]
]

] # end of BS namespace
